rig/providers/
galadriel.rs

1//! Galadriel API client and Rig integration
2//!
3//! # Example
4//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::galadriel;
//!
//! let client = galadriel::Client::new("YOUR_API_KEY");
//! // to use a fine-tuned model
//! // let client = galadriel::Client::builder("YOUR_API_KEY")
//! //     .fine_tune_api_key("FINE_TUNE_API_KEY")
//! //     .build()
//! //     .expect("Galadriel client should build");
//!
//! let gpt4o = client.completion_model(galadriel::GPT_4O);
12//! ```
13use super::openai;
14use crate::client::{
15    ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError,
16};
17use crate::json_utils::merge;
18use crate::message::MessageError;
19use crate::providers::openai::send_compatible_streaming_request;
20use crate::streaming::StreamingCompletionResponse;
21use crate::{
22    OneOrMany,
23    completion::{self, CompletionError, CompletionRequest},
24    impl_conversion_traits, json_utils, message,
25};
26use serde::{Deserialize, Serialize};
27use serde_json::{Value, json};
28
29// ================================================================
30// Main Galadriel Client
31// ================================================================
32const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
33
34pub struct ClientBuilder<'a> {
35    api_key: &'a str,
36    fine_tune_api_key: Option<&'a str>,
37    base_url: &'a str,
38    http_client: Option<reqwest::Client>,
39}
40
41impl<'a> ClientBuilder<'a> {
42    pub fn new(api_key: &'a str) -> Self {
43        Self {
44            api_key,
45            fine_tune_api_key: None,
46            base_url: GALADRIEL_API_BASE_URL,
47            http_client: None,
48        }
49    }
50
51    pub fn fine_tune_api_key(mut self, fine_tune_api_key: &'a str) -> Self {
52        self.fine_tune_api_key = Some(fine_tune_api_key);
53        self
54    }
55
56    pub fn base_url(mut self, base_url: &'a str) -> Self {
57        self.base_url = base_url;
58        self
59    }
60
61    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
62        self.http_client = Some(client);
63        self
64    }
65
66    pub fn build(self) -> Result<Client, ClientBuilderError> {
67        let http_client = if let Some(http_client) = self.http_client {
68            http_client
69        } else {
70            reqwest::Client::builder().build()?
71        };
72
73        Ok(Client {
74            base_url: self.base_url.to_string(),
75            api_key: self.api_key.to_string(),
76            fine_tune_api_key: self.fine_tune_api_key.map(|x| x.to_string()),
77            http_client,
78        })
79    }
80}
81#[derive(Clone)]
82pub struct Client {
83    base_url: String,
84    api_key: String,
85    fine_tune_api_key: Option<String>,
86    http_client: reqwest::Client,
87}
88
89impl std::fmt::Debug for Client {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        f.debug_struct("Client")
92            .field("base_url", &self.base_url)
93            .field("http_client", &self.http_client)
94            .field("api_key", &"<REDACTED>")
95            .field("fine_tune_api_key", &"<REDACTED>")
96            .finish()
97    }
98}
99
100impl Client {
101    /// Create a new Galadriel client builder.
102    ///
103    /// # Example
104    /// ```
105    /// use rig::providers::galadriel::{ClientBuilder, self};
106    ///
107    /// // Initialize the Galadriel client
108    /// let galadriel = Client::builder("your-galadriel-api-key")
109    ///    .build()
110    /// ```
111    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
112        ClientBuilder::new(api_key)
113    }
114
115    /// Create a new Galadriel client. For more control, use the `builder` method.
116    ///
117    /// # Panics
118    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
119    pub fn new(api_key: &str) -> Self {
120        Self::builder(api_key)
121            .build()
122            .expect("Galadriel client should build")
123    }
124
125    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
126        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
127        let mut client = self.http_client.post(url).bearer_auth(&self.api_key);
128
129        if let Some(fine_tune_key) = self.fine_tune_api_key.clone() {
130            client = client.header("Fine-Tune-Authorization", fine_tune_key);
131        }
132
133        client
134    }
135}
136
137impl ProviderClient for Client {
138    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
139    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
140    /// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
141    fn from_env() -> Self {
142        let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
143        let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
144        let mut builder = Self::builder(&api_key);
145        if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
146            builder = builder.fine_tune_api_key(fine_tune_api_key);
147        }
148        builder.build().expect("Galadriel client should build")
149    }
150
151    fn from_val(input: crate::client::ProviderValue) -> Self {
152        let crate::client::ProviderValue::ApiKeyWithOptionalKey(api_key, fine_tune_key) = input
153        else {
154            panic!("Incorrect provider value type")
155        };
156        let mut builder = Self::builder(&api_key);
157        if let Some(fine_tune_key) = fine_tune_key.as_deref() {
158            builder = builder.fine_tune_api_key(fine_tune_key);
159        }
160        builder.build().expect("Galadriel client should build")
161    }
162}
163
164impl CompletionClient for Client {
165    type CompletionModel = CompletionModel;
166
167    /// Create a completion model with the given name.
168    ///
169    /// # Example
170    /// ```
171    /// use rig::providers::galadriel::{Client, self};
172    ///
173    /// // Initialize the Galadriel client
174    /// let galadriel = Client::new("your-galadriel-api-key", None);
175    ///
176    /// let gpt4 = galadriel.completion_model(galadriel::GPT_4);
177    /// ```
178    fn completion_model(&self, model: &str) -> CompletionModel {
179        CompletionModel::new(self.clone(), model)
180    }
181}
182
183impl VerifyClient for Client {
184    #[cfg_attr(feature = "worker", worker::send)]
185    async fn verify(&self) -> Result<(), VerifyError> {
186        // Could not find an API endpoint to verify the API key
187        Ok(())
188    }
189}
190
191impl_conversion_traits!(
192    AsEmbeddings,
193    AsTranscription,
194    AsImageGeneration,
195    AsAudioGeneration for Client
196);
197
198#[derive(Debug, Deserialize)]
199struct ApiErrorResponse {
200    message: String,
201}
202
203#[derive(Debug, Deserialize)]
204#[serde(untagged)]
205enum ApiResponse<T> {
206    Ok(T),
207    Err(ApiErrorResponse),
208}
209
210#[derive(Clone, Debug, Deserialize, Serialize)]
211pub struct Usage {
212    pub prompt_tokens: usize,
213    pub total_tokens: usize,
214}
215
216impl std::fmt::Display for Usage {
217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218        write!(
219            f,
220            "Prompt tokens: {} Total tokens: {}",
221            self.prompt_tokens, self.total_tokens
222        )
223    }
224}
225
226// ================================================================
227// Galadriel Completion API
228// ================================================================
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
271
272#[derive(Debug, Deserialize, Serialize)]
273pub struct CompletionResponse {
274    pub id: String,
275    pub object: String,
276    pub created: u64,
277    pub model: String,
278    pub system_fingerprint: Option<String>,
279    pub choices: Vec<Choice>,
280    pub usage: Option<Usage>,
281}
282
283impl From<ApiErrorResponse> for CompletionError {
284    fn from(err: ApiErrorResponse) -> Self {
285        CompletionError::ProviderError(err.message)
286    }
287}
288
289impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
290    type Error = CompletionError;
291
292    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
293        let Choice { message, .. } = response.choices.first().ok_or_else(|| {
294            CompletionError::ResponseError("Response contained no choices".to_owned())
295        })?;
296
297        let mut content = message
298            .content
299            .as_ref()
300            .map(|c| vec![completion::AssistantContent::text(c)])
301            .unwrap_or_default();
302
303        content.extend(message.tool_calls.iter().map(|call| {
304            completion::AssistantContent::tool_call(
305                &call.function.name,
306                &call.function.name,
307                call.function.arguments.clone(),
308            )
309        }));
310
311        let choice = OneOrMany::many(content).map_err(|_| {
312            CompletionError::ResponseError(
313                "Response contained no message or tool call (empty)".to_owned(),
314            )
315        })?;
316        let usage = response
317            .usage
318            .as_ref()
319            .map(|usage| completion::Usage {
320                input_tokens: usage.prompt_tokens as u64,
321                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
322                total_tokens: usage.total_tokens as u64,
323            })
324            .unwrap_or_default();
325
326        Ok(completion::CompletionResponse {
327            choice,
328            usage,
329            raw_response: response,
330        })
331    }
332}
333
334#[derive(Debug, Deserialize, Serialize)]
335pub struct Choice {
336    pub index: usize,
337    pub message: Message,
338    pub logprobs: Option<serde_json::Value>,
339    pub finish_reason: String,
340}
341
342#[derive(Debug, Serialize, Deserialize)]
343pub struct Message {
344    pub role: String,
345    pub content: Option<String>,
346    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
347    pub tool_calls: Vec<openai::ToolCall>,
348}
349
350impl TryFrom<Message> for message::Message {
351    type Error = message::MessageError;
352
353    fn try_from(message: Message) -> Result<Self, Self::Error> {
354        let tool_calls: Vec<message::ToolCall> = message
355            .tool_calls
356            .into_iter()
357            .map(|tool_call| tool_call.into())
358            .collect();
359
360        match message.role.as_str() {
361            "user" => Ok(Self::User {
362                content: OneOrMany::one(
363                    message
364                        .content
365                        .map(|content| message::UserContent::text(&content))
366                        .ok_or_else(|| {
367                            message::MessageError::ConversionError("Empty user message".to_string())
368                        })?,
369                ),
370            }),
371            "assistant" => Ok(Self::Assistant {
372                id: None,
373                content: OneOrMany::many(
374                    tool_calls
375                        .into_iter()
376                        .map(message::AssistantContent::ToolCall)
377                        .chain(
378                            message
379                                .content
380                                .map(|content| message::AssistantContent::text(&content))
381                                .into_iter(),
382                        ),
383                )
384                .map_err(|_| {
385                    message::MessageError::ConversionError("Empty assistant message".to_string())
386                })?,
387            }),
388            _ => Err(message::MessageError::ConversionError(format!(
389                "Unknown role: {}",
390                message.role
391            ))),
392        }
393    }
394}
395
396impl TryFrom<message::Message> for Message {
397    type Error = message::MessageError;
398
399    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
400        match message {
401            message::Message::User { content } => Ok(Self {
402                role: "user".to_string(),
403                content: content.iter().find_map(|c| match c {
404                    message::UserContent::Text(text) => Some(text.text.clone()),
405                    _ => None,
406                }),
407                tool_calls: vec![],
408            }),
409            message::Message::Assistant { content, .. } => {
410                let mut text_content: Option<String> = None;
411                let mut tool_calls = vec![];
412
413                for c in content.iter() {
414                    match c {
415                        message::AssistantContent::Text(text) => {
416                            text_content = Some(
417                                text_content
418                                    .map(|mut existing| {
419                                        existing.push('\n');
420                                        existing.push_str(&text.text);
421                                        existing
422                                    })
423                                    .unwrap_or_else(|| text.text.clone()),
424                            );
425                        }
426                        message::AssistantContent::ToolCall(tool_call) => {
427                            tool_calls.push(tool_call.clone().into());
428                        }
429                        message::AssistantContent::Reasoning(_) => {
430                            return Err(MessageError::ConversionError(
431                                "Galadriel currently doesn't support reasoning.".into(),
432                            ));
433                        }
434                    }
435                }
436
437                Ok(Self {
438                    role: "assistant".to_string(),
439                    content: text_content,
440                    tool_calls,
441                })
442            }
443        }
444    }
445}
446
447#[derive(Clone, Debug, Deserialize, Serialize)]
448pub struct ToolDefinition {
449    pub r#type: String,
450    pub function: completion::ToolDefinition,
451}
452
453impl From<completion::ToolDefinition> for ToolDefinition {
454    fn from(tool: completion::ToolDefinition) -> Self {
455        Self {
456            r#type: "function".into(),
457            function: tool,
458        }
459    }
460}
461
462#[derive(Debug, Deserialize)]
463pub struct Function {
464    pub name: String,
465    pub arguments: String,
466}
467
468#[derive(Clone)]
469pub struct CompletionModel {
470    client: Client,
471    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
472    pub model: String,
473}
474
475impl CompletionModel {
476    pub(crate) fn create_completion_request(
477        &self,
478        completion_request: CompletionRequest,
479    ) -> Result<Value, CompletionError> {
480        // Build up the order of messages (context, chat_history, prompt)
481        let mut partial_history = vec![];
482        if let Some(docs) = completion_request.normalized_documents() {
483            partial_history.push(docs);
484        }
485        partial_history.extend(completion_request.chat_history);
486
487        // Add preamble to chat history (if available)
488        let mut full_history: Vec<Message> = match &completion_request.preamble {
489            Some(preamble) => vec![Message {
490                role: "system".to_string(),
491                content: Some(preamble.to_string()),
492                tool_calls: vec![],
493            }],
494            None => vec![],
495        };
496
497        // Convert and extend the rest of the history
498        full_history.extend(
499            partial_history
500                .into_iter()
501                .map(message::Message::try_into)
502                .collect::<Result<Vec<Message>, _>>()?,
503        );
504
505        let request = if completion_request.tools.is_empty() {
506            json!({
507                "model": self.model,
508                "messages": full_history,
509                "temperature": completion_request.temperature,
510            })
511        } else {
512            json!({
513                "model": self.model,
514                "messages": full_history,
515                "temperature": completion_request.temperature,
516                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
517                "tool_choice": "auto",
518            })
519        };
520
521        let request = if let Some(params) = completion_request.additional_params {
522            json_utils::merge(request, params)
523        } else {
524            request
525        };
526
527        Ok(request)
528    }
529}
530
531impl CompletionModel {
532    pub fn new(client: Client, model: &str) -> Self {
533        Self {
534            client,
535            model: model.to_string(),
536        }
537    }
538}
539
540impl completion::CompletionModel for CompletionModel {
541    type Response = CompletionResponse;
542    type StreamingResponse = openai::StreamingCompletionResponse;
543
544    #[cfg_attr(feature = "worker", worker::send)]
545    async fn completion(
546        &self,
547        completion_request: CompletionRequest,
548    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
549        let request = self.create_completion_request(completion_request)?;
550
551        let response = self
552            .client
553            .post("/chat/completions")
554            .json(&request)
555            .send()
556            .await?;
557
558        if response.status().is_success() {
559            let t = response.text().await?;
560            tracing::debug!(target: "rig", "Galadriel completion error: {}", t);
561
562            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
563                ApiResponse::Ok(response) => {
564                    tracing::info!(target: "rig",
565                        "Galadriel completion token usage: {:?}",
566                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
567                    );
568                    response.try_into()
569                }
570                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
571            }
572        } else {
573            Err(CompletionError::ProviderError(response.text().await?))
574        }
575    }
576
577    #[cfg_attr(feature = "worker", worker::send)]
578    async fn stream(
579        &self,
580        request: CompletionRequest,
581    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
582        let mut request = self.create_completion_request(request)?;
583
584        request = merge(
585            request,
586            json!({"stream": true, "stream_options": {"include_usage": true}}),
587        );
588
589        let builder = self.client.post("/chat/completions").json(&request);
590
591        send_compatible_streaming_request(builder).await
592    }
593}