rig/providers/
galadriel.rs

1//! Galadriel API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::galadriel;
6//!
7//! let client = galadriel::Client::new("YOUR_API_KEY", None);
8//! // to use a fine-tuned model
//! // let client = galadriel::Client::new("YOUR_API_KEY", Some("FINE_TUNE_API_KEY"));
10//!
11//! let gpt4o = client.completion_model(galadriel::GPT_4O);
12//! ```
13use super::openai;
14use crate::client::{CompletionClient, ProviderClient};
15use crate::json_utils::merge;
16use crate::providers::openai::send_compatible_streaming_request;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19    OneOrMany,
20    completion::{self, CompletionError, CompletionRequest},
21    impl_conversion_traits, json_utils, message,
22};
23use serde::{Deserialize, Serialize};
24use serde_json::{Value, json};
25
26// ================================================================
27// Main Galadriel Client
28// ================================================================
/// Base URL used by [`Client::new`] when the caller does not supply one.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
30
31#[derive(Clone)]
32pub struct Client {
33    base_url: String,
34    api_key: String,
35    fine_tune_api_key: Option<String>,
36    http_client: reqwest::Client,
37}
38
39impl std::fmt::Debug for Client {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        f.debug_struct("Client")
42            .field("base_url", &self.base_url)
43            .field("http_client", &self.http_client)
44            .field("api_key", &"<REDACTED>")
45            .field("fine_tune_api_key", &"<REDACTED>")
46            .finish()
47    }
48}
49
50impl Client {
51    /// Create a new Galadriel client with the given API key and optional fine-tune API key.
52    pub fn new(api_key: &str, fine_tune_api_key: Option<&str>) -> Self {
53        Self::from_url_with_optional_key(api_key, GALADRIEL_API_BASE_URL, fine_tune_api_key)
54    }
55
56    /// Create a new Galadriel client with the given API key, base API URL, and optional fine-tune API key.
57    pub fn from_url(api_key: &str, base_url: &str, fine_tune_api_key: Option<&str>) -> Self {
58        Self::from_url_with_optional_key(api_key, base_url, fine_tune_api_key)
59    }
60
61    pub fn from_url_with_optional_key(
62        api_key: &str,
63        base_url: &str,
64        fine_tune_api_key: Option<&str>,
65    ) -> Self {
66        Self {
67            base_url: base_url.to_string(),
68            api_key: api_key.to_string(),
69            fine_tune_api_key: fine_tune_api_key.map(|x| x.to_string()),
70            http_client: reqwest::Client::builder()
71                .build()
72                .expect("Galadriel reqwest client should build"),
73        }
74    }
75
76    /// Use your own `reqwest::Client`.
77    /// The default headers will be automatically attached upon trying to make a request.
78    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
79        self.http_client = client;
80
81        self
82    }
83
84    fn post(&self, path: &str) -> reqwest::RequestBuilder {
85        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
86        let mut client = self.http_client.post(url).bearer_auth(&self.api_key);
87
88        if let Some(fine_tune_key) = self.fine_tune_api_key.clone() {
89            client = client.header("Fine-Tune-Authorization", fine_tune_key);
90        }
91
92        client
93    }
94}
95
96impl ProviderClient for Client {
97    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
98    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
99    /// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
100    fn from_env() -> Self {
101        let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
102        let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
103        Self::new(&api_key, fine_tune_api_key.as_deref())
104    }
105
106    fn from_val(input: crate::client::ProviderValue) -> Self {
107        let crate::client::ProviderValue::ApiKeyWithOptionalKey(api_key, fine_tune_key) = input
108        else {
109            panic!("Incorrect provider value type")
110        };
111        Self::new(&api_key, fine_tune_key.as_deref())
112    }
113}
114
115impl CompletionClient for Client {
116    type CompletionModel = CompletionModel;
117
118    /// Create a completion model with the given name.
119    ///
120    /// # Example
121    /// ```
122    /// use rig::providers::galadriel::{Client, self};
123    ///
124    /// // Initialize the Galadriel client
125    /// let galadriel = Client::new("your-galadriel-api-key", None);
126    ///
127    /// let gpt4 = galadriel.completion_model(galadriel::GPT_4);
128    /// ```
129    fn completion_model(&self, model: &str) -> CompletionModel {
130        CompletionModel::new(self.clone(), model)
131    }
132}
133
134impl_conversion_traits!(
135    AsEmbeddings,
136    AsTranscription,
137    AsImageGeneration,
138    AsAudioGeneration for Client
139);
140
141#[derive(Debug, Deserialize)]
142struct ApiErrorResponse {
143    message: String,
144}
145
146#[derive(Debug, Deserialize)]
147#[serde(untagged)]
148enum ApiResponse<T> {
149    Ok(T),
150    Err(ApiErrorResponse),
151}
152
153#[derive(Clone, Debug, Deserialize)]
154pub struct Usage {
155    pub prompt_tokens: usize,
156    pub total_tokens: usize,
157}
158
159impl std::fmt::Display for Usage {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        write!(
162            f,
163            "Prompt tokens: {} Total tokens: {}",
164            self.prompt_tokens, self.total_tokens
165        )
166    }
167}
168
169// ================================================================
170// Galadriel Completion API
171// ================================================================
172/// `o1-preview` completion model
173pub const O1_PREVIEW: &str = "o1-preview";
174/// `o1-preview-2024-09-12` completion model
175pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
177pub const O1_MINI: &str = "o1-mini";
178/// `o1-mini-2024-09-12` completion model
179pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
180/// `gpt-4o` completion model
181pub const GPT_4O: &str = "gpt-4o";
182/// `gpt-4o-2024-05-13` completion model
183pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
184/// `gpt-4-turbo` completion model
185pub const GPT_4_TURBO: &str = "gpt-4-turbo";
186/// `gpt-4-turbo-2024-04-09` completion model
187pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
188/// `gpt-4-turbo-preview` completion model
189pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
190/// `gpt-4-0125-preview` completion model
191pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
192/// `gpt-4-1106-preview` completion model
193pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
194/// `gpt-4-vision-preview` completion model
195pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
196/// `gpt-4-1106-vision-preview` completion model
197pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
198/// `gpt-4` completion model
199pub const GPT_4: &str = "gpt-4";
200/// `gpt-4-0613` completion model
201pub const GPT_4_0613: &str = "gpt-4-0613";
202/// `gpt-4-32k` completion model
203pub const GPT_4_32K: &str = "gpt-4-32k";
204/// `gpt-4-32k-0613` completion model
205pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
206/// `gpt-3.5-turbo` completion model
207pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
208/// `gpt-3.5-turbo-0125` completion model
209pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
210/// `gpt-3.5-turbo-1106` completion model
211pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
212/// `gpt-3.5-turbo-instruct` completion model
213pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
214
215#[derive(Debug, Deserialize)]
216pub struct CompletionResponse {
217    pub id: String,
218    pub object: String,
219    pub created: u64,
220    pub model: String,
221    pub system_fingerprint: Option<String>,
222    pub choices: Vec<Choice>,
223    pub usage: Option<Usage>,
224}
225
226impl From<ApiErrorResponse> for CompletionError {
227    fn from(err: ApiErrorResponse) -> Self {
228        CompletionError::ProviderError(err.message)
229    }
230}
231
232impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
233    type Error = CompletionError;
234
235    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
236        let Choice { message, .. } = response.choices.first().ok_or_else(|| {
237            CompletionError::ResponseError("Response contained no choices".to_owned())
238        })?;
239
240        let mut content = message
241            .content
242            .as_ref()
243            .map(|c| vec![completion::AssistantContent::text(c)])
244            .unwrap_or_default();
245
246        content.extend(message.tool_calls.iter().map(|call| {
247            completion::AssistantContent::tool_call(
248                &call.function.name,
249                &call.function.name,
250                call.function.arguments.clone(),
251            )
252        }));
253
254        let choice = OneOrMany::many(content).map_err(|_| {
255            CompletionError::ResponseError(
256                "Response contained no message or tool call (empty)".to_owned(),
257            )
258        })?;
259        let usage = response
260            .usage
261            .as_ref()
262            .map(|usage| completion::Usage {
263                input_tokens: usage.prompt_tokens as u64,
264                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
265                total_tokens: usage.total_tokens as u64,
266            })
267            .unwrap_or_default();
268
269        Ok(completion::CompletionResponse {
270            choice,
271            usage,
272            raw_response: response,
273        })
274    }
275}
276
277#[derive(Debug, Deserialize)]
278pub struct Choice {
279    pub index: usize,
280    pub message: Message,
281    pub logprobs: Option<serde_json::Value>,
282    pub finish_reason: String,
283}
284
285#[derive(Debug, Serialize, Deserialize)]
286pub struct Message {
287    pub role: String,
288    pub content: Option<String>,
289    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
290    pub tool_calls: Vec<openai::ToolCall>,
291}
292
293impl TryFrom<Message> for message::Message {
294    type Error = message::MessageError;
295
296    fn try_from(message: Message) -> Result<Self, Self::Error> {
297        let tool_calls: Vec<message::ToolCall> = message
298            .tool_calls
299            .into_iter()
300            .map(|tool_call| tool_call.into())
301            .collect();
302
303        match message.role.as_str() {
304            "user" => Ok(Self::User {
305                content: OneOrMany::one(
306                    message
307                        .content
308                        .map(|content| message::UserContent::text(&content))
309                        .ok_or_else(|| {
310                            message::MessageError::ConversionError("Empty user message".to_string())
311                        })?,
312                ),
313            }),
314            "assistant" => Ok(Self::Assistant {
315                id: None,
316                content: OneOrMany::many(
317                    tool_calls
318                        .into_iter()
319                        .map(message::AssistantContent::ToolCall)
320                        .chain(
321                            message
322                                .content
323                                .map(|content| message::AssistantContent::text(&content))
324                                .into_iter(),
325                        ),
326                )
327                .map_err(|_| {
328                    message::MessageError::ConversionError("Empty assistant message".to_string())
329                })?,
330            }),
331            _ => Err(message::MessageError::ConversionError(format!(
332                "Unknown role: {}",
333                message.role
334            ))),
335        }
336    }
337}
338
339impl TryFrom<message::Message> for Message {
340    type Error = message::MessageError;
341
342    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
343        match message {
344            message::Message::User { content } => Ok(Self {
345                role: "user".to_string(),
346                content: content.iter().find_map(|c| match c {
347                    message::UserContent::Text(text) => Some(text.text.clone()),
348                    _ => None,
349                }),
350                tool_calls: vec![],
351            }),
352            message::Message::Assistant { content, .. } => {
353                let mut text_content: Option<String> = None;
354                let mut tool_calls = vec![];
355
356                for c in content.iter() {
357                    match c {
358                        message::AssistantContent::Text(text) => {
359                            text_content = Some(
360                                text_content
361                                    .map(|mut existing| {
362                                        existing.push('\n');
363                                        existing.push_str(&text.text);
364                                        existing
365                                    })
366                                    .unwrap_or_else(|| text.text.clone()),
367                            );
368                        }
369                        message::AssistantContent::ToolCall(tool_call) => {
370                            tool_calls.push(tool_call.clone().into());
371                        }
372                    }
373                }
374
375                Ok(Self {
376                    role: "assistant".to_string(),
377                    content: text_content,
378                    tool_calls,
379                })
380            }
381        }
382    }
383}
384
385#[derive(Clone, Debug, Deserialize, Serialize)]
386pub struct ToolDefinition {
387    pub r#type: String,
388    pub function: completion::ToolDefinition,
389}
390
391impl From<completion::ToolDefinition> for ToolDefinition {
392    fn from(tool: completion::ToolDefinition) -> Self {
393        Self {
394            r#type: "function".into(),
395            function: tool,
396        }
397    }
398}
399
400#[derive(Debug, Deserialize)]
401pub struct Function {
402    pub name: String,
403    pub arguments: String,
404}
405
406#[derive(Clone)]
407pub struct CompletionModel {
408    client: Client,
409    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
410    pub model: String,
411}
412
413impl CompletionModel {
414    pub(crate) fn create_completion_request(
415        &self,
416        completion_request: CompletionRequest,
417    ) -> Result<Value, CompletionError> {
418        // Build up the order of messages (context, chat_history, prompt)
419        let mut partial_history = vec![];
420        if let Some(docs) = completion_request.normalized_documents() {
421            partial_history.push(docs);
422        }
423        partial_history.extend(completion_request.chat_history);
424
425        // Add preamble to chat history (if available)
426        let mut full_history: Vec<Message> = match &completion_request.preamble {
427            Some(preamble) => vec![Message {
428                role: "system".to_string(),
429                content: Some(preamble.to_string()),
430                tool_calls: vec![],
431            }],
432            None => vec![],
433        };
434
435        // Convert and extend the rest of the history
436        full_history.extend(
437            partial_history
438                .into_iter()
439                .map(message::Message::try_into)
440                .collect::<Result<Vec<Message>, _>>()?,
441        );
442
443        let request = if completion_request.tools.is_empty() {
444            json!({
445                "model": self.model,
446                "messages": full_history,
447                "temperature": completion_request.temperature,
448            })
449        } else {
450            json!({
451                "model": self.model,
452                "messages": full_history,
453                "temperature": completion_request.temperature,
454                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
455                "tool_choice": "auto",
456            })
457        };
458
459        let request = if let Some(params) = completion_request.additional_params {
460            json_utils::merge(request, params)
461        } else {
462            request
463        };
464
465        Ok(request)
466    }
467}
468
469impl CompletionModel {
470    pub fn new(client: Client, model: &str) -> Self {
471        Self {
472            client,
473            model: model.to_string(),
474        }
475    }
476}
477
478impl completion::CompletionModel for CompletionModel {
479    type Response = CompletionResponse;
480    type StreamingResponse = openai::StreamingCompletionResponse;
481
482    #[cfg_attr(feature = "worker", worker::send)]
483    async fn completion(
484        &self,
485        completion_request: CompletionRequest,
486    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
487        let request = self.create_completion_request(completion_request)?;
488
489        let response = self
490            .client
491            .post("/chat/completions")
492            .json(&request)
493            .send()
494            .await?;
495
496        if response.status().is_success() {
497            let t = response.text().await?;
498            tracing::debug!(target: "rig", "Galadriel completion error: {}", t);
499
500            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
501                ApiResponse::Ok(response) => {
502                    tracing::info!(target: "rig",
503                        "Galadriel completion token usage: {:?}",
504                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
505                    );
506                    response.try_into()
507                }
508                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
509            }
510        } else {
511            Err(CompletionError::ProviderError(response.text().await?))
512        }
513    }
514
515    #[cfg_attr(feature = "worker", worker::send)]
516    async fn stream(
517        &self,
518        request: CompletionRequest,
519    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
520        let mut request = self.create_completion_request(request)?;
521
522        request = merge(
523            request,
524            json!({"stream": true, "stream_options": {"include_usage": true}}),
525        );
526
527        let builder = self.client.post("/chat/completions").json(&request);
528
529        send_compatible_streaming_request(builder).await
530    }
531}