rig/providers/
galadriel.rs

//! Galadriel API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::galadriel;
//!
//! let client = galadriel::Client::new("YOUR_API_KEY", None);
//! // to use a fine-tuned model, pass the fine-tune key as `Some(..)`
//! // let client = galadriel::Client::new("YOUR_API_KEY", Some("FINE_TUNE_API_KEY"));
//!
//! let gpt4o = client.completion_model(galadriel::GPT_4O);
//! ```
13use super::openai;
14use crate::json_utils::merge;
15use crate::providers::openai::send_compatible_streaming_request;
16use crate::streaming::{StreamingCompletionModel, StreamingResult};
17use crate::{
18    agent::AgentBuilder,
19    completion::{self, CompletionError, CompletionRequest},
20    extractor::ExtractorBuilder,
21    json_utils, message, OneOrMany,
22};
23use schemars::JsonSchema;
24use serde::{Deserialize, Serialize};
25use serde_json::{json, Value};
26
27// ================================================================
28// Main Galadriel Client
29// ================================================================
/// Base URL used by [`Client::new`] and [`Client::from_env`] when no explicit URL is given.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
31
32#[derive(Clone)]
33pub struct Client {
34    base_url: String,
35    http_client: reqwest::Client,
36}
37
38impl Client {
39    /// Create a new Galadriel client with the given API key and optional fine-tune API key.
40    pub fn new(api_key: &str, fine_tune_api_key: Option<&str>) -> Self {
41        Self::from_url_with_optional_key(api_key, GALADRIEL_API_BASE_URL, fine_tune_api_key)
42    }
43
44    /// Create a new Galadriel client with the given API key, base API URL, and optional fine-tune API key.
45    pub fn from_url(api_key: &str, base_url: &str, fine_tune_api_key: Option<&str>) -> Self {
46        Self::from_url_with_optional_key(api_key, base_url, fine_tune_api_key)
47    }
48
49    pub fn from_url_with_optional_key(
50        api_key: &str,
51        base_url: &str,
52        fine_tune_api_key: Option<&str>,
53    ) -> Self {
54        Self {
55            base_url: base_url.to_string(),
56            http_client: reqwest::Client::builder()
57                .default_headers({
58                    let mut headers = reqwest::header::HeaderMap::new();
59                    headers.insert(
60                        "Authorization",
61                        format!("Bearer {}", api_key)
62                            .parse()
63                            .expect("Bearer token should parse"),
64                    );
65                    if let Some(key) = fine_tune_api_key {
66                        headers.insert(
67                            "Fine-Tune-Authorization",
68                            format!("Bearer {}", key)
69                                .parse()
70                                .expect("Bearer token should parse"),
71                        );
72                    }
73                    headers
74                })
75                .build()
76                .expect("Galadriel reqwest client should build"),
77        }
78    }
79
80    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
81    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
82    /// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
83    pub fn from_env() -> Self {
84        let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
85        let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
86        Self::new(&api_key, fine_tune_api_key.as_deref())
87    }
88    fn post(&self, path: &str) -> reqwest::RequestBuilder {
89        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
90        self.http_client.post(url)
91    }
92
93    /// Create a completion model with the given name.
94    ///
95    /// # Example
96    /// ```
97    /// use rig::providers::galadriel::{Client, self};
98    ///
99    /// // Initialize the Galadriel client
100    /// let galadriel = Client::new("your-galadriel-api-key", None);
101    ///
102    /// let gpt4 = galadriel.completion_model(galadriel::GPT_4);
103    /// ```
104    pub fn completion_model(&self, model: &str) -> CompletionModel {
105        CompletionModel::new(self.clone(), model)
106    }
107
108    /// Create an agent builder with the given completion model.
109    ///
110    /// # Example
111    /// ```
112    /// use rig::providers::galadriel::{Client, self};
113    ///
114    /// // Initialize the Galadriel client
115    /// let galadriel = Client::new("your-galadriel-api-key", None);
116    ///
117    /// let agent = galadriel.agent(galadriel::GPT_4)
118    ///    .preamble("You are comedian AI with a mission to make people laugh.")
119    ///    .temperature(0.0)
120    ///    .build();
121    /// ```
122    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
123        AgentBuilder::new(self.completion_model(model))
124    }
125
126    /// Create an extractor builder with the given completion model.
127    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
128        &self,
129        model: &str,
130    ) -> ExtractorBuilder<T, CompletionModel> {
131        ExtractorBuilder::new(self.completion_model(model))
132    }
133}
134
135#[derive(Debug, Deserialize)]
136struct ApiErrorResponse {
137    message: String,
138}
139
140#[derive(Debug, Deserialize)]
141#[serde(untagged)]
142enum ApiResponse<T> {
143    Ok(T),
144    Err(ApiErrorResponse),
145}
146
147#[derive(Clone, Debug, Deserialize)]
148pub struct Usage {
149    pub prompt_tokens: usize,
150    pub total_tokens: usize,
151}
152
153impl std::fmt::Display for Usage {
154    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155        write!(
156            f,
157            "Prompt tokens: {} Total tokens: {}",
158            self.prompt_tokens, self.total_tokens
159        )
160    }
161}
162
163// ================================================================
164// Galadriel Completion API
165// ================================================================
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
208
209#[derive(Debug, Deserialize)]
210pub struct CompletionResponse {
211    pub id: String,
212    pub object: String,
213    pub created: u64,
214    pub model: String,
215    pub system_fingerprint: Option<String>,
216    pub choices: Vec<Choice>,
217    pub usage: Option<Usage>,
218}
219
220impl From<ApiErrorResponse> for CompletionError {
221    fn from(err: ApiErrorResponse) -> Self {
222        CompletionError::ProviderError(err.message)
223    }
224}
225
226impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
227    type Error = CompletionError;
228
229    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
230        let Choice { message, .. } = response.choices.first().ok_or_else(|| {
231            CompletionError::ResponseError("Response contained no choices".to_owned())
232        })?;
233
234        let mut content = message
235            .content
236            .as_ref()
237            .map(|c| vec![completion::AssistantContent::text(c)])
238            .unwrap_or_default();
239
240        content.extend(message.tool_calls.iter().map(|call| {
241            completion::AssistantContent::tool_call(
242                &call.function.name,
243                &call.function.name,
244                call.function.arguments.clone(),
245            )
246        }));
247
248        let choice = OneOrMany::many(content).map_err(|_| {
249            CompletionError::ResponseError(
250                "Response contained no message or tool call (empty)".to_owned(),
251            )
252        })?;
253
254        Ok(completion::CompletionResponse {
255            choice,
256            raw_response: response,
257        })
258    }
259}
260
261#[derive(Debug, Deserialize)]
262pub struct Choice {
263    pub index: usize,
264    pub message: Message,
265    pub logprobs: Option<serde_json::Value>,
266    pub finish_reason: String,
267}
268
269#[derive(Debug, Serialize, Deserialize)]
270pub struct Message {
271    pub role: String,
272    pub content: Option<String>,
273    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
274    pub tool_calls: Vec<openai::ToolCall>,
275}
276
277impl TryFrom<Message> for message::Message {
278    type Error = message::MessageError;
279
280    fn try_from(message: Message) -> Result<Self, Self::Error> {
281        let tool_calls: Vec<message::ToolCall> = message
282            .tool_calls
283            .into_iter()
284            .map(|tool_call| tool_call.into())
285            .collect();
286
287        match message.role.as_str() {
288            "user" => Ok(Self::User {
289                content: OneOrMany::one(
290                    message
291                        .content
292                        .map(|content| message::UserContent::text(&content))
293                        .ok_or_else(|| {
294                            message::MessageError::ConversionError("Empty user message".to_string())
295                        })?,
296                ),
297            }),
298            "assistant" => Ok(Self::Assistant {
299                content: OneOrMany::many(
300                    tool_calls
301                        .into_iter()
302                        .map(message::AssistantContent::ToolCall)
303                        .chain(
304                            message
305                                .content
306                                .map(|content| message::AssistantContent::text(&content))
307                                .into_iter(),
308                        ),
309                )
310                .map_err(|_| {
311                    message::MessageError::ConversionError("Empty assistant message".to_string())
312                })?,
313            }),
314            _ => Err(message::MessageError::ConversionError(format!(
315                "Unknown role: {}",
316                message.role
317            ))),
318        }
319    }
320}
321
322impl TryFrom<message::Message> for Message {
323    type Error = message::MessageError;
324
325    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
326        match message {
327            message::Message::User { content } => Ok(Self {
328                role: "user".to_string(),
329                content: content.iter().find_map(|c| match c {
330                    message::UserContent::Text(text) => Some(text.text.clone()),
331                    _ => None,
332                }),
333                tool_calls: vec![],
334            }),
335            message::Message::Assistant { content } => {
336                let mut text_content: Option<String> = None;
337                let mut tool_calls = vec![];
338
339                for c in content.iter() {
340                    match c {
341                        message::AssistantContent::Text(text) => {
342                            text_content = Some(
343                                text_content
344                                    .map(|mut existing| {
345                                        existing.push('\n');
346                                        existing.push_str(&text.text);
347                                        existing
348                                    })
349                                    .unwrap_or_else(|| text.text.clone()),
350                            );
351                        }
352                        message::AssistantContent::ToolCall(tool_call) => {
353                            tool_calls.push(tool_call.clone().into());
354                        }
355                    }
356                }
357
358                Ok(Self {
359                    role: "assistant".to_string(),
360                    content: text_content,
361                    tool_calls,
362                })
363            }
364        }
365    }
366}
367
368#[derive(Clone, Debug, Deserialize, Serialize)]
369pub struct ToolDefinition {
370    pub r#type: String,
371    pub function: completion::ToolDefinition,
372}
373
374impl From<completion::ToolDefinition> for ToolDefinition {
375    fn from(tool: completion::ToolDefinition) -> Self {
376        Self {
377            r#type: "function".into(),
378            function: tool,
379        }
380    }
381}
382
383#[derive(Debug, Deserialize)]
384pub struct Function {
385    pub name: String,
386    pub arguments: String,
387}
388
389#[derive(Clone)]
390pub struct CompletionModel {
391    client: Client,
392    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
393    pub model: String,
394}
395
396impl CompletionModel {
397    pub(crate) fn create_completion_request(
398        &self,
399        completion_request: CompletionRequest,
400    ) -> Result<Value, CompletionError> {
401        // Build up the order of messages (context, chat_history, prompt)
402        let mut partial_history = vec![];
403        if let Some(docs) = completion_request.normalized_documents() {
404            partial_history.push(docs);
405        }
406        partial_history.extend(completion_request.chat_history);
407
408        // Add preamble to chat history (if available)
409        let mut full_history: Vec<Message> = match &completion_request.preamble {
410            Some(preamble) => vec![Message {
411                role: "system".to_string(),
412                content: Some(preamble.to_string()),
413                tool_calls: vec![],
414            }],
415            None => vec![],
416        };
417
418        // Convert and extend the rest of the history
419        full_history.extend(
420            partial_history
421                .into_iter()
422                .map(message::Message::try_into)
423                .collect::<Result<Vec<Message>, _>>()?,
424        );
425
426        let request = if completion_request.tools.is_empty() {
427            json!({
428                "model": self.model,
429                "messages": full_history,
430                "temperature": completion_request.temperature,
431            })
432        } else {
433            json!({
434                "model": self.model,
435                "messages": full_history,
436                "temperature": completion_request.temperature,
437                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
438                "tool_choice": "auto",
439            })
440        };
441
442        let request = if let Some(params) = completion_request.additional_params {
443            json_utils::merge(request, params)
444        } else {
445            request
446        };
447
448        Ok(request)
449    }
450}
451
452impl CompletionModel {
453    pub fn new(client: Client, model: &str) -> Self {
454        Self {
455            client,
456            model: model.to_string(),
457        }
458    }
459}
460
461impl completion::CompletionModel for CompletionModel {
462    type Response = CompletionResponse;
463
464    #[cfg_attr(feature = "worker", worker::send)]
465    async fn completion(
466        &self,
467        completion_request: CompletionRequest,
468    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
469        let request = self.create_completion_request(completion_request)?;
470
471        let response = self
472            .client
473            .post("/chat/completions")
474            .json(&request)
475            .send()
476            .await?;
477
478        if response.status().is_success() {
479            let t = response.text().await?;
480            tracing::debug!(target: "rig", "Galadriel completion error: {}", t);
481
482            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
483                ApiResponse::Ok(response) => {
484                    tracing::info!(target: "rig",
485                        "Galadriel completion token usage: {:?}",
486                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
487                    );
488                    response.try_into()
489                }
490                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
491            }
492        } else {
493            Err(CompletionError::ProviderError(response.text().await?))
494        }
495    }
496}
497
498impl StreamingCompletionModel for CompletionModel {
499    async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
500        let mut request = self.create_completion_request(request)?;
501
502        request = merge(request, json!({"stream": true}));
503
504        let builder = self.client.post("/chat/completions").json(&request);
505
506        send_compatible_streaming_request(builder).await
507    }
508}