rig/providers/
galadriel.rs

1//! Galadriel API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::galadriel;
6//!
7//! let client = galadriel::Client::new("YOUR_API_KEY", None);
8//! // to use a fine-tuned model
//! // let client = galadriel::Client::new("YOUR_API_KEY", Some("FINE_TUNE_API_KEY"));
10//!
11//! let gpt4o = client.completion_model(galadriel::GPT_4O);
12//! ```
13use super::openai;
14use crate::client::{CompletionClient, ProviderClient};
15use crate::json_utils::merge;
16use crate::providers::openai::send_compatible_streaming_request;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19    completion::{self, CompletionError, CompletionRequest},
20    impl_conversion_traits, json_utils, message, OneOrMany,
21};
22use serde::{Deserialize, Serialize};
23use serde_json::{json, Value};
24
25// ================================================================
26// Main Galadriel Client
27// ================================================================
28const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
29
/// Client for the Galadriel verified-inference API.
///
/// Holds the base URL and a `reqwest::Client` whose default headers carry the
/// API key (and, optionally, a fine-tune API key), so every request built from
/// it is already authenticated.
#[derive(Clone, Debug)]
pub struct Client {
    // Base URL that request paths are joined onto (defaults to `GALADRIEL_API_BASE_URL`).
    base_url: String,
    // HTTP client pre-configured with the `Authorization` (and optional
    // `Fine-Tune-Authorization`) headers.
    http_client: reqwest::Client,
}
35
36impl Client {
37    /// Create a new Galadriel client with the given API key and optional fine-tune API key.
38    pub fn new(api_key: &str, fine_tune_api_key: Option<&str>) -> Self {
39        Self::from_url_with_optional_key(api_key, GALADRIEL_API_BASE_URL, fine_tune_api_key)
40    }
41
42    /// Create a new Galadriel client with the given API key, base API URL, and optional fine-tune API key.
43    pub fn from_url(api_key: &str, base_url: &str, fine_tune_api_key: Option<&str>) -> Self {
44        Self::from_url_with_optional_key(api_key, base_url, fine_tune_api_key)
45    }
46
47    pub fn from_url_with_optional_key(
48        api_key: &str,
49        base_url: &str,
50        fine_tune_api_key: Option<&str>,
51    ) -> Self {
52        Self {
53            base_url: base_url.to_string(),
54            http_client: reqwest::Client::builder()
55                .default_headers({
56                    let mut headers = reqwest::header::HeaderMap::new();
57                    headers.insert(
58                        "Authorization",
59                        format!("Bearer {api_key}")
60                            .parse()
61                            .expect("Bearer token should parse"),
62                    );
63                    if let Some(key) = fine_tune_api_key {
64                        headers.insert(
65                            "Fine-Tune-Authorization",
66                            format!("Bearer {key}")
67                                .parse()
68                                .expect("Bearer token should parse"),
69                        );
70                    }
71                    headers
72                })
73                .build()
74                .expect("Galadriel reqwest client should build"),
75        }
76    }
77
78    fn post(&self, path: &str) -> reqwest::RequestBuilder {
79        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
80        self.http_client.post(url)
81    }
82}
83
impl ProviderClient for Client {
    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
    /// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
    fn from_env() -> Self {
        let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
        // A missing fine-tune key is not an error — it is simply omitted from the headers.
        let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
        Self::new(&api_key, fine_tune_api_key.as_deref())
    }
}
94
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::galadriel::{Client, self};
    ///
    /// // Initialize the Galadriel client
    /// let galadriel = Client::new("your-galadriel-api-key", None);
    ///
    /// let gpt4 = galadriel.completion_model(galadriel::GPT_4);
    /// ```
    fn completion_model(&self, model: &str) -> CompletionModel {
        // The client handle is cloned into each model so models can be used independently.
        CompletionModel::new(self.clone(), model)
    }
}
113
// Only chat completions are implemented for Galadriel; this macro presumably
// generates the "capability not supported" conversions for the remaining
// provider traits — see `impl_conversion_traits!` for the exact expansion.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
120
/// Error payload returned by the Galadriel API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Human-readable error description from the provider.
    message: String,
}
125
/// Either a successful payload of type `T` or a provider error body.
///
/// `untagged` makes serde try the variants in order: `T` first, then the
/// error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
132
/// Token usage statistics reported by the completion endpoint.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    // Tokens consumed by the prompt.
    pub prompt_tokens: usize,
    // Total token count reported by the API for the request.
    pub total_tokens: usize,
}
138
139impl std::fmt::Display for Usage {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        write!(
142            f,
143            "Prompt tokens: {} Total tokens: {}",
144            self.prompt_tokens, self.total_tokens
145        )
146    }
147}
148
149// ================================================================
150// Galadriel Completion API
151// ================================================================
152/// `o1-preview` completion model
153pub const O1_PREVIEW: &str = "o1-preview";
154/// `o1-preview-2024-09-12` completion model
155pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
156/// `o1-mini completion model
157pub const O1_MINI: &str = "o1-mini";
158/// `o1-mini-2024-09-12` completion model
159pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
160/// `gpt-4o` completion model
161pub const GPT_4O: &str = "gpt-4o";
162/// `gpt-4o-2024-05-13` completion model
163pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
164/// `gpt-4-turbo` completion model
165pub const GPT_4_TURBO: &str = "gpt-4-turbo";
166/// `gpt-4-turbo-2024-04-09` completion model
167pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
168/// `gpt-4-turbo-preview` completion model
169pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
170/// `gpt-4-0125-preview` completion model
171pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
172/// `gpt-4-1106-preview` completion model
173pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
174/// `gpt-4-vision-preview` completion model
175pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
176/// `gpt-4-1106-vision-preview` completion model
177pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
178/// `gpt-4` completion model
179pub const GPT_4: &str = "gpt-4";
180/// `gpt-4-0613` completion model
181pub const GPT_4_0613: &str = "gpt-4-0613";
182/// `gpt-4-32k` completion model
183pub const GPT_4_32K: &str = "gpt-4-32k";
184/// `gpt-4-32k-0613` completion model
185pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
186/// `gpt-3.5-turbo` completion model
187pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
188/// `gpt-3.5-turbo-0125` completion model
189pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
190/// `gpt-3.5-turbo-1106` completion model
191pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
192/// `gpt-3.5-turbo-instruct` completion model
193pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
194
/// Raw completion response in Galadriel's OpenAI-compatible wire format.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    /// Candidate completions; only the first is consumed by the `TryFrom` conversion below.
    pub choices: Vec<Choice>,
    /// Token accounting, when the provider includes it.
    pub usage: Option<Usage>,
}
205
206impl From<ApiErrorResponse> for CompletionError {
207    fn from(err: ApiErrorResponse) -> Self {
208        CompletionError::ProviderError(err.message)
209    }
210}
211
212impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
213    type Error = CompletionError;
214
215    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
216        let Choice { message, .. } = response.choices.first().ok_or_else(|| {
217            CompletionError::ResponseError("Response contained no choices".to_owned())
218        })?;
219
220        let mut content = message
221            .content
222            .as_ref()
223            .map(|c| vec![completion::AssistantContent::text(c)])
224            .unwrap_or_default();
225
226        content.extend(message.tool_calls.iter().map(|call| {
227            completion::AssistantContent::tool_call(
228                &call.function.name,
229                &call.function.name,
230                call.function.arguments.clone(),
231            )
232        }));
233
234        let choice = OneOrMany::many(content).map_err(|_| {
235            CompletionError::ResponseError(
236                "Response contained no message or tool call (empty)".to_owned(),
237            )
238        })?;
239
240        Ok(completion::CompletionResponse {
241            choice,
242            raw_response: response,
243        })
244    }
245}
246
/// A single completion candidate within a response.
#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    // Kept as raw JSON; not interpreted by this module.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
254
/// Chat message in Galadriel's OpenAI-compatible schema.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Message role; this module produces "system", "user", and "assistant".
    pub role: String,
    /// Text content; `None` when the assistant responded with tool calls only.
    pub content: Option<String>,
    // A missing or `null` tool_calls field deserializes to an empty vec
    // (see `json_utils::null_or_vec`).
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<openai::ToolCall>,
}
262
263impl TryFrom<Message> for message::Message {
264    type Error = message::MessageError;
265
266    fn try_from(message: Message) -> Result<Self, Self::Error> {
267        let tool_calls: Vec<message::ToolCall> = message
268            .tool_calls
269            .into_iter()
270            .map(|tool_call| tool_call.into())
271            .collect();
272
273        match message.role.as_str() {
274            "user" => Ok(Self::User {
275                content: OneOrMany::one(
276                    message
277                        .content
278                        .map(|content| message::UserContent::text(&content))
279                        .ok_or_else(|| {
280                            message::MessageError::ConversionError("Empty user message".to_string())
281                        })?,
282                ),
283            }),
284            "assistant" => Ok(Self::Assistant {
285                content: OneOrMany::many(
286                    tool_calls
287                        .into_iter()
288                        .map(message::AssistantContent::ToolCall)
289                        .chain(
290                            message
291                                .content
292                                .map(|content| message::AssistantContent::text(&content))
293                                .into_iter(),
294                        ),
295                )
296                .map_err(|_| {
297                    message::MessageError::ConversionError("Empty assistant message".to_string())
298                })?,
299            }),
300            _ => Err(message::MessageError::ConversionError(format!(
301                "Unknown role: {}",
302                message.role
303            ))),
304        }
305    }
306}
307
308impl TryFrom<message::Message> for Message {
309    type Error = message::MessageError;
310
311    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
312        match message {
313            message::Message::User { content } => Ok(Self {
314                role: "user".to_string(),
315                content: content.iter().find_map(|c| match c {
316                    message::UserContent::Text(text) => Some(text.text.clone()),
317                    _ => None,
318                }),
319                tool_calls: vec![],
320            }),
321            message::Message::Assistant { content } => {
322                let mut text_content: Option<String> = None;
323                let mut tool_calls = vec![];
324
325                for c in content.iter() {
326                    match c {
327                        message::AssistantContent::Text(text) => {
328                            text_content = Some(
329                                text_content
330                                    .map(|mut existing| {
331                                        existing.push('\n');
332                                        existing.push_str(&text.text);
333                                        existing
334                                    })
335                                    .unwrap_or_else(|| text.text.clone()),
336                            );
337                        }
338                        message::AssistantContent::ToolCall(tool_call) => {
339                            tool_calls.push(tool_call.clone().into());
340                        }
341                    }
342                }
343
344                Ok(Self {
345                    role: "assistant".to_string(),
346                    content: text_content,
347                    tool_calls,
348                })
349            }
350        }
351    }
352}
353
/// OpenAI-style tool wrapper: a type tag (always "function" here) around the
/// provider-agnostic tool definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
359
360impl From<completion::ToolDefinition> for ToolDefinition {
361    fn from(tool: completion::ToolDefinition) -> Self {
362        Self {
363            r#type: "function".into(),
364            function: tool,
365        }
366    }
367}
368
/// Function name/arguments pair of a tool call.
///
/// NOTE(review): not referenced elsewhere in this file — `Message` uses
/// `openai::ToolCall` instead. Possibly dead code; confirm before removing.
#[derive(Debug, Deserialize)]
pub struct Function {
    pub name: String,
    // Raw JSON-encoded argument string, as sent by the provider.
    pub arguments: String,
}
374
/// Completion model handle: a `Client` plus the model identifier to request.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
    pub model: String,
}
381
382impl CompletionModel {
383    pub(crate) fn create_completion_request(
384        &self,
385        completion_request: CompletionRequest,
386    ) -> Result<Value, CompletionError> {
387        // Build up the order of messages (context, chat_history, prompt)
388        let mut partial_history = vec![];
389        if let Some(docs) = completion_request.normalized_documents() {
390            partial_history.push(docs);
391        }
392        partial_history.extend(completion_request.chat_history);
393
394        // Add preamble to chat history (if available)
395        let mut full_history: Vec<Message> = match &completion_request.preamble {
396            Some(preamble) => vec![Message {
397                role: "system".to_string(),
398                content: Some(preamble.to_string()),
399                tool_calls: vec![],
400            }],
401            None => vec![],
402        };
403
404        // Convert and extend the rest of the history
405        full_history.extend(
406            partial_history
407                .into_iter()
408                .map(message::Message::try_into)
409                .collect::<Result<Vec<Message>, _>>()?,
410        );
411
412        let request = if completion_request.tools.is_empty() {
413            json!({
414                "model": self.model,
415                "messages": full_history,
416                "temperature": completion_request.temperature,
417            })
418        } else {
419            json!({
420                "model": self.model,
421                "messages": full_history,
422                "temperature": completion_request.temperature,
423                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
424                "tool_choice": "auto",
425            })
426        };
427
428        let request = if let Some(params) = completion_request.additional_params {
429            json_utils::merge(request, params)
430        } else {
431            request
432        };
433
434        Ok(request)
435    }
436}
437
438impl CompletionModel {
439    pub fn new(client: Client, model: &str) -> Self {
440        Self {
441            client,
442            model: model.to_string(),
443        }
444    }
445}
446
447impl completion::CompletionModel for CompletionModel {
448    type Response = CompletionResponse;
449    type StreamingResponse = openai::StreamingCompletionResponse;
450
451    #[cfg_attr(feature = "worker", worker::send)]
452    async fn completion(
453        &self,
454        completion_request: CompletionRequest,
455    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
456        let request = self.create_completion_request(completion_request)?;
457
458        let response = self
459            .client
460            .post("/chat/completions")
461            .json(&request)
462            .send()
463            .await?;
464
465        if response.status().is_success() {
466            let t = response.text().await?;
467            tracing::debug!(target: "rig", "Galadriel completion error: {}", t);
468
469            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
470                ApiResponse::Ok(response) => {
471                    tracing::info!(target: "rig",
472                        "Galadriel completion token usage: {:?}",
473                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
474                    );
475                    response.try_into()
476                }
477                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
478            }
479        } else {
480            Err(CompletionError::ProviderError(response.text().await?))
481        }
482    }
483
484    #[cfg_attr(feature = "worker", worker::send)]
485    async fn stream(
486        &self,
487        request: CompletionRequest,
488    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
489        let mut request = self.create_completion_request(request)?;
490
491        request = merge(
492            request,
493            json!({"stream": true, "stream_options": {"include_usage": true}}),
494        );
495
496        let builder = self.client.post("/chat/completions").json(&request);
497
498        send_compatible_streaming_request(builder).await
499    }
500}