rig/providers/
galadriel.rs

//! Galadriel API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::galadriel;
//!
//! let client = galadriel::Client::new("YOUR_API_KEY", None);
//! // to use a fine-tuned model, also pass the fine-tune API key
//! // let client = galadriel::Client::new("YOUR_API_KEY", Some("FINE_TUNE_API_KEY"));
//!
//! let gpt4o = client.completion_model(galadriel::GPT_4O);
//! ```
13use crate::{
14    agent::AgentBuilder,
15    completion::{self, CompletionError, CompletionRequest},
16    extractor::ExtractorBuilder,
17    json_utils, message, OneOrMany,
18};
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use serde_json::json;
22
23use super::openai;
24
// ----------------------------------------------------------------
// Main Galadriel Client
// ----------------------------------------------------------------

/// Base URL used by `Client::new`; request paths are appended to it.
const GALADRIEL_API_BASE_URL: &'static str = "https://api.galadriel.com/v1/verified";
29
30#[derive(Clone)]
31pub struct Client {
32    base_url: String,
33    http_client: reqwest::Client,
34}
35
36impl Client {
37    /// Create a new Galadriel client with the given API key and optional fine-tune API key.
38    pub fn new(api_key: &str, fine_tune_api_key: Option<&str>) -> Self {
39        Self::from_url_with_optional_key(api_key, GALADRIEL_API_BASE_URL, fine_tune_api_key)
40    }
41
42    /// Create a new Galadriel client with the given API key, base API URL, and optional fine-tune API key.
43    pub fn from_url(api_key: &str, base_url: &str, fine_tune_api_key: Option<&str>) -> Self {
44        Self::from_url_with_optional_key(api_key, base_url, fine_tune_api_key)
45    }
46
47    pub fn from_url_with_optional_key(
48        api_key: &str,
49        base_url: &str,
50        fine_tune_api_key: Option<&str>,
51    ) -> Self {
52        Self {
53            base_url: base_url.to_string(),
54            http_client: reqwest::Client::builder()
55                .default_headers({
56                    let mut headers = reqwest::header::HeaderMap::new();
57                    headers.insert(
58                        "Authorization",
59                        format!("Bearer {}", api_key)
60                            .parse()
61                            .expect("Bearer token should parse"),
62                    );
63                    if let Some(key) = fine_tune_api_key {
64                        headers.insert(
65                            "Fine-Tune-Authorization",
66                            format!("Bearer {}", key)
67                                .parse()
68                                .expect("Bearer token should parse"),
69                        );
70                    }
71                    headers
72                })
73                .build()
74                .expect("Galadriel reqwest client should build"),
75        }
76    }
77
78    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
79    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
80    /// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
81    pub fn from_env() -> Self {
82        let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
83        let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
84        Self::new(&api_key, fine_tune_api_key.as_deref())
85    }
86    fn post(&self, path: &str) -> reqwest::RequestBuilder {
87        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
88        self.http_client.post(url)
89    }
90
91    /// Create a completion model with the given name.
92    ///
93    /// # Example
94    /// ```
95    /// use rig::providers::galadriel::{Client, self};
96    ///
97    /// // Initialize the Galadriel client
98    /// let galadriel = Client::new("your-galadriel-api-key", None);
99    ///
100    /// let gpt4 = galadriel.completion_model(galadriel::GPT_4);
101    /// ```
102    pub fn completion_model(&self, model: &str) -> CompletionModel {
103        CompletionModel::new(self.clone(), model)
104    }
105
106    /// Create an agent builder with the given completion model.
107    ///
108    /// # Example
109    /// ```
110    /// use rig::providers::galadriel::{Client, self};
111    ///
112    /// // Initialize the Galadriel client
113    /// let galadriel = Client::new("your-galadriel-api-key", None);
114    ///
115    /// let agent = galadriel.agent(galadriel::GPT_4)
116    ///    .preamble("You are comedian AI with a mission to make people laugh.")
117    ///    .temperature(0.0)
118    ///    .build();
119    /// ```
120    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
121        AgentBuilder::new(self.completion_model(model))
122    }
123
124    /// Create an extractor builder with the given completion model.
125    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
126        &self,
127        model: &str,
128    ) -> ExtractorBuilder<T, CompletionModel> {
129        ExtractorBuilder::new(self.completion_model(model))
130    }
131}
132
133#[derive(Debug, Deserialize)]
134struct ApiErrorResponse {
135    message: String,
136}
137
138#[derive(Debug, Deserialize)]
139#[serde(untagged)]
140enum ApiResponse<T> {
141    Ok(T),
142    Err(ApiErrorResponse),
143}
144
145#[derive(Clone, Debug, Deserialize)]
146pub struct Usage {
147    pub prompt_tokens: usize,
148    pub total_tokens: usize,
149}
150
151impl std::fmt::Display for Usage {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        write!(
154            f,
155            "Prompt tokens: {} Total tokens: {}",
156            self.prompt_tokens, self.total_tokens
157        )
158    }
159}
160
// ================================================================
// Galadriel Completion API
// ================================================================
// Model identifiers accepted by `Client::completion_model` / `Client::agent`.

/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
206
207#[derive(Debug, Deserialize)]
208pub struct CompletionResponse {
209    pub id: String,
210    pub object: String,
211    pub created: u64,
212    pub model: String,
213    pub system_fingerprint: Option<String>,
214    pub choices: Vec<Choice>,
215    pub usage: Option<Usage>,
216}
217
218impl From<ApiErrorResponse> for CompletionError {
219    fn from(err: ApiErrorResponse) -> Self {
220        CompletionError::ProviderError(err.message)
221    }
222}
223
224impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
225    type Error = CompletionError;
226
227    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
228        let Choice { message, .. } = response.choices.first().ok_or_else(|| {
229            CompletionError::ResponseError("Response contained no choices".to_owned())
230        })?;
231
232        let mut content = message
233            .content
234            .as_ref()
235            .map(|c| vec![completion::AssistantContent::text(c)])
236            .unwrap_or_default();
237
238        content.extend(message.tool_calls.iter().map(|call| {
239            completion::AssistantContent::tool_call(
240                &call.function.name,
241                &call.function.name,
242                call.function.arguments.clone(),
243            )
244        }));
245
246        let choice = OneOrMany::many(content).map_err(|_| {
247            CompletionError::ResponseError(
248                "Response contained no message or tool call (empty)".to_owned(),
249            )
250        })?;
251
252        Ok(completion::CompletionResponse {
253            choice,
254            raw_response: response,
255        })
256    }
257}
258
259#[derive(Debug, Deserialize)]
260pub struct Choice {
261    pub index: usize,
262    pub message: Message,
263    pub logprobs: Option<serde_json::Value>,
264    pub finish_reason: String,
265}
266
267#[derive(Debug, Serialize, Deserialize)]
268pub struct Message {
269    pub role: String,
270    pub content: Option<String>,
271    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
272    pub tool_calls: Vec<openai::ToolCall>,
273}
274
275impl TryFrom<Message> for message::Message {
276    type Error = message::MessageError;
277
278    fn try_from(message: Message) -> Result<Self, Self::Error> {
279        let tool_calls: Vec<message::ToolCall> = message
280            .tool_calls
281            .into_iter()
282            .map(|tool_call| tool_call.into())
283            .collect();
284
285        match message.role.as_str() {
286            "user" => Ok(Self::User {
287                content: OneOrMany::one(
288                    message
289                        .content
290                        .map(|content| message::UserContent::text(&content))
291                        .ok_or_else(|| {
292                            message::MessageError::ConversionError("Empty user message".to_string())
293                        })?,
294                ),
295            }),
296            "assistant" => Ok(Self::Assistant {
297                content: OneOrMany::many(
298                    tool_calls
299                        .into_iter()
300                        .map(message::AssistantContent::ToolCall)
301                        .chain(
302                            message
303                                .content
304                                .map(|content| message::AssistantContent::text(&content))
305                                .into_iter(),
306                        ),
307                )
308                .map_err(|_| {
309                    message::MessageError::ConversionError("Empty assistant message".to_string())
310                })?,
311            }),
312            _ => Err(message::MessageError::ConversionError(format!(
313                "Unknown role: {}",
314                message.role
315            ))),
316        }
317    }
318}
319
320impl TryFrom<message::Message> for Message {
321    type Error = message::MessageError;
322
323    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
324        match message {
325            message::Message::User { content } => Ok(Self {
326                role: "user".to_string(),
327                content: content.iter().find_map(|c| match c {
328                    message::UserContent::Text(text) => Some(text.text.clone()),
329                    _ => None,
330                }),
331                tool_calls: vec![],
332            }),
333            message::Message::Assistant { content } => {
334                let mut text_content: Option<String> = None;
335                let mut tool_calls = vec![];
336
337                for c in content.iter() {
338                    match c {
339                        message::AssistantContent::Text(text) => {
340                            text_content = Some(
341                                text_content
342                                    .map(|mut existing| {
343                                        existing.push('\n');
344                                        existing.push_str(&text.text);
345                                        existing
346                                    })
347                                    .unwrap_or_else(|| text.text.clone()),
348                            );
349                        }
350                        message::AssistantContent::ToolCall(tool_call) => {
351                            tool_calls.push(tool_call.clone().into());
352                        }
353                    }
354                }
355
356                Ok(Self {
357                    role: "assistant".to_string(),
358                    content: text_content,
359                    tool_calls,
360                })
361            }
362        }
363    }
364}
365
366#[derive(Clone, Debug, Deserialize, Serialize)]
367pub struct ToolDefinition {
368    pub r#type: String,
369    pub function: completion::ToolDefinition,
370}
371
372impl From<completion::ToolDefinition> for ToolDefinition {
373    fn from(tool: completion::ToolDefinition) -> Self {
374        Self {
375            r#type: "function".into(),
376            function: tool,
377        }
378    }
379}
380
381#[derive(Debug, Deserialize)]
382pub struct Function {
383    pub name: String,
384    pub arguments: String,
385}
386
387#[derive(Clone)]
388pub struct CompletionModel {
389    client: Client,
390    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
391    pub model: String,
392}
393
394impl CompletionModel {
395    pub fn new(client: Client, model: &str) -> Self {
396        Self {
397            client,
398            model: model.to_string(),
399        }
400    }
401}
402
403impl completion::CompletionModel for CompletionModel {
404    type Response = CompletionResponse;
405
406    #[cfg_attr(feature = "worker", worker::send)]
407    async fn completion(
408        &self,
409        completion_request: CompletionRequest,
410    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
411        // Add preamble to chat history (if available)
412        let mut full_history: Vec<Message> = match &completion_request.preamble {
413            Some(preamble) => vec![Message {
414                role: "system".to_string(),
415                content: Some(preamble.to_string()),
416                tool_calls: vec![],
417            }],
418            None => vec![],
419        };
420
421        // Convert prompt to user message
422        let prompt: Message = completion_request.prompt_with_context().try_into()?;
423
424        // Convert existing chat history
425        let chat_history: Vec<Message> = completion_request
426            .chat_history
427            .into_iter()
428            .map(|message| message.try_into())
429            .collect::<Result<Vec<Message>, _>>()?;
430
431        // Combine all messages into a single history
432        full_history.extend(chat_history);
433        full_history.push(prompt);
434
435        let request = if completion_request.tools.is_empty() {
436            json!({
437                "model": self.model,
438                "messages": full_history,
439                "temperature": completion_request.temperature,
440            })
441        } else {
442            json!({
443                "model": self.model,
444                "messages": full_history,
445                "temperature": completion_request.temperature,
446                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
447                "tool_choice": "auto",
448            })
449        };
450
451        let response = self
452            .client
453            .post("/chat/completions")
454            .json(
455                &if let Some(params) = completion_request.additional_params {
456                    json_utils::merge(request, params)
457                } else {
458                    request
459                },
460            )
461            .send()
462            .await?;
463
464        if response.status().is_success() {
465            let t = response.text().await?;
466            tracing::debug!(target: "rig", "Galadriel completion error: {}", t);
467
468            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
469                ApiResponse::Ok(response) => {
470                    tracing::info!(target: "rig",
471                        "Galadriel completion token usage: {:?}",
472                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
473                    );
474                    response.try_into()
475                }
476                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
477            }
478        } else {
479            Err(CompletionError::ProviderError(response.text().await?))
480        }
481    }
482}