rig/providers/
galadriel.rs

1//! Galadriel API client and Rig integration
2//!
3//! # Example
4//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::galadriel;
//!
//! let client = galadriel::Client::new("YOUR_API_KEY");
//! // to use a fine-tuned model
//! // let client = galadriel::Client::builder("YOUR_API_KEY")
//! //     .fine_tune_api_key("FINE_TUNE_API_KEY")
//! //     .build()
//! //     .expect("Galadriel client should build");
//!
//! let gpt4o = client.completion_model(galadriel::GPT_4O);
12//! ```
13use super::openai;
14use crate::client::{ClientBuilderError, CompletionClient, ProviderClient};
15use crate::json_utils::merge;
16use crate::message::MessageError;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20    OneOrMany,
21    completion::{self, CompletionError, CompletionRequest},
22    impl_conversion_traits, json_utils, message,
23};
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26
27// ================================================================
28// Main Galadriel Client
29// ================================================================
30const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
31
/// Builder for a Galadriel [`Client`].
///
/// The string settings are borrowed for `'a` and only copied into owned
/// `String`s when [`ClientBuilder::build`] is called.
pub struct ClientBuilder<'a> {
    /// API key; the built client sends it as the bearer token.
    api_key: &'a str,
    /// Optional key; when set, the built client sends it in the
    /// `Fine-Tune-Authorization` header.
    fine_tune_api_key: Option<&'a str>,
    /// Base URL that request paths are appended to.
    base_url: &'a str,
    /// Caller-supplied HTTP client; a default one is built when this is `None`.
    http_client: Option<reqwest::Client>,
}
38
39impl<'a> ClientBuilder<'a> {
40    pub fn new(api_key: &'a str) -> Self {
41        Self {
42            api_key,
43            fine_tune_api_key: None,
44            base_url: GALADRIEL_API_BASE_URL,
45            http_client: None,
46        }
47    }
48
49    pub fn fine_tune_api_key(mut self, fine_tune_api_key: &'a str) -> Self {
50        self.fine_tune_api_key = Some(fine_tune_api_key);
51        self
52    }
53
54    pub fn base_url(mut self, base_url: &'a str) -> Self {
55        self.base_url = base_url;
56        self
57    }
58
59    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
60        self.http_client = Some(client);
61        self
62    }
63
64    pub fn build(self) -> Result<Client, ClientBuilderError> {
65        let http_client = if let Some(http_client) = self.http_client {
66            http_client
67        } else {
68            reqwest::Client::builder().build()?
69        };
70
71        Ok(Client {
72            base_url: self.base_url.to_string(),
73            api_key: self.api_key.to_string(),
74            fine_tune_api_key: self.fine_tune_api_key.map(|x| x.to_string()),
75            http_client,
76        })
77    }
78}
/// Galadriel API client.
///
/// Holds the credentials and HTTP client used to issue requests. The manual
/// `Debug` implementation for this type redacts both keys.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// API key sent as the bearer token.
    api_key: String,
    /// Optional key sent in the `Fine-Tune-Authorization` header.
    fine_tune_api_key: Option<String>,
    /// Underlying HTTP client used for all requests.
    http_client: reqwest::Client,
}
86
87impl std::fmt::Debug for Client {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("Client")
90            .field("base_url", &self.base_url)
91            .field("http_client", &self.http_client)
92            .field("api_key", &"<REDACTED>")
93            .field("fine_tune_api_key", &"<REDACTED>")
94            .finish()
95    }
96}
97
98impl Client {
99    /// Create a new Galadriel client builder.
100    ///
101    /// # Example
102    /// ```
103    /// use rig::providers::galadriel::{ClientBuilder, self};
104    ///
105    /// // Initialize the Galadriel client
106    /// let galadriel = Client::builder("your-galadriel-api-key")
107    ///    .build()
108    /// ```
109    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
110        ClientBuilder::new(api_key)
111    }
112
113    /// Create a new Galadriel client. For more control, use the `builder` method.
114    ///
115    /// # Panics
116    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
117    pub fn new(api_key: &str) -> Self {
118        Self::builder(api_key)
119            .build()
120            .expect("Galadriel client should build")
121    }
122
123    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
124        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
125        let mut client = self.http_client.post(url).bearer_auth(&self.api_key);
126
127        if let Some(fine_tune_key) = self.fine_tune_api_key.clone() {
128            client = client.header("Fine-Tune-Authorization", fine_tune_key);
129        }
130
131        client
132    }
133}
134
135impl ProviderClient for Client {
136    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
137    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
138    /// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
139    fn from_env() -> Self {
140        let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
141        let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
142        let mut builder = Self::builder(&api_key);
143        if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
144            builder = builder.fine_tune_api_key(fine_tune_api_key);
145        }
146        builder.build().expect("Galadriel client should build")
147    }
148
149    fn from_val(input: crate::client::ProviderValue) -> Self {
150        let crate::client::ProviderValue::ApiKeyWithOptionalKey(api_key, fine_tune_key) = input
151        else {
152            panic!("Incorrect provider value type")
153        };
154        let mut builder = Self::builder(&api_key);
155        if let Some(fine_tune_key) = fine_tune_key.as_deref() {
156            builder = builder.fine_tune_api_key(fine_tune_key);
157        }
158        builder.build().expect("Galadriel client should build")
159    }
160}
161
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Create a completion model with the given name.
    ///
    /// The returned model holds its own clone of this client.
    ///
    /// # Example
    /// ```
    /// use rig::client::CompletionClient;
    /// use rig::providers::galadriel::{self, Client};
    ///
    /// // Initialize the Galadriel client
    /// let client = Client::new("your-galadriel-api-key");
    ///
    /// let gpt4 = client.completion_model(galadriel::GPT_4);
    /// ```
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
180
// Only `CompletionClient` is implemented by hand for this provider.
// NOTE(review): presumably this macro generates the remaining `As*`
// conversion traits for `Client` as "capability not available" stubs —
// confirm against the definition of `impl_conversion_traits!`.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
187
/// Error payload returned by the API; only the `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error description reported by the provider.
    message: String,
}

/// Either a successful payload of type `T` or an API error payload.
///
/// The enum is `untagged`, so deserialization is driven purely by shape:
/// serde tries the variants in declaration order, i.e. `Ok(T)` first and
/// `Err` only if the body does not parse as `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
199
200#[derive(Clone, Debug, Deserialize, Serialize)]
201pub struct Usage {
202    pub prompt_tokens: usize,
203    pub total_tokens: usize,
204}
205
206impl std::fmt::Display for Usage {
207    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208        write!(
209            f,
210            "Prompt tokens: {} Total tokens: {}",
211            self.prompt_tokens, self.total_tokens
212        )
213    }
214}
215
216// ================================================================
217// Galadriel Completion API
218// ================================================================
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
261
/// Body of a successful `/chat/completions` response as deserialized by this
/// provider.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    /// Candidate completions; only the first one is used when converting into
    /// rig's generic completion response.
    pub choices: Vec<Choice>,
    /// Token accounting, when the provider reports it.
    pub usage: Option<Usage>,
}
272
273impl From<ApiErrorResponse> for CompletionError {
274    fn from(err: ApiErrorResponse) -> Self {
275        CompletionError::ProviderError(err.message)
276    }
277}
278
279impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
280    type Error = CompletionError;
281
282    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
283        let Choice { message, .. } = response.choices.first().ok_or_else(|| {
284            CompletionError::ResponseError("Response contained no choices".to_owned())
285        })?;
286
287        let mut content = message
288            .content
289            .as_ref()
290            .map(|c| vec![completion::AssistantContent::text(c)])
291            .unwrap_or_default();
292
293        content.extend(message.tool_calls.iter().map(|call| {
294            completion::AssistantContent::tool_call(
295                &call.function.name,
296                &call.function.name,
297                call.function.arguments.clone(),
298            )
299        }));
300
301        let choice = OneOrMany::many(content).map_err(|_| {
302            CompletionError::ResponseError(
303                "Response contained no message or tool call (empty)".to_owned(),
304            )
305        })?;
306        let usage = response
307            .usage
308            .as_ref()
309            .map(|usage| completion::Usage {
310                input_tokens: usage.prompt_tokens as u64,
311                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
312                total_tokens: usage.total_tokens as u64,
313            })
314            .unwrap_or_default();
315
316        Ok(completion::CompletionResponse {
317            choice,
318            usage,
319            raw_response: response,
320        })
321    }
322}
323
/// One candidate completion inside a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    /// The message produced for this choice.
    pub message: Message,
    /// Kept as untyped JSON; this file never inspects it.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}

/// Chat message in the wire format used for both requests and responses.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Role name as a plain string; this file uses `"system"`, `"user"` and
    /// `"assistant"`.
    pub role: String,
    /// Text content, if any.
    pub content: Option<String>,
    /// Tool calls attached to the message. A missing field deserializes to an
    /// empty vector via `default`.
    // NOTE(review): `json_utils::null_or_vec` presumably also maps an explicit
    // JSON `null` to an empty vector — confirm against its definition.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<openai::ToolCall>,
}
339
340impl TryFrom<Message> for message::Message {
341    type Error = message::MessageError;
342
343    fn try_from(message: Message) -> Result<Self, Self::Error> {
344        let tool_calls: Vec<message::ToolCall> = message
345            .tool_calls
346            .into_iter()
347            .map(|tool_call| tool_call.into())
348            .collect();
349
350        match message.role.as_str() {
351            "user" => Ok(Self::User {
352                content: OneOrMany::one(
353                    message
354                        .content
355                        .map(|content| message::UserContent::text(&content))
356                        .ok_or_else(|| {
357                            message::MessageError::ConversionError("Empty user message".to_string())
358                        })?,
359                ),
360            }),
361            "assistant" => Ok(Self::Assistant {
362                id: None,
363                content: OneOrMany::many(
364                    tool_calls
365                        .into_iter()
366                        .map(message::AssistantContent::ToolCall)
367                        .chain(
368                            message
369                                .content
370                                .map(|content| message::AssistantContent::text(&content))
371                                .into_iter(),
372                        ),
373                )
374                .map_err(|_| {
375                    message::MessageError::ConversionError("Empty assistant message".to_string())
376                })?,
377            }),
378            _ => Err(message::MessageError::ConversionError(format!(
379                "Unknown role: {}",
380                message.role
381            ))),
382        }
383    }
384}
385
386impl TryFrom<message::Message> for Message {
387    type Error = message::MessageError;
388
389    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
390        match message {
391            message::Message::User { content } => Ok(Self {
392                role: "user".to_string(),
393                content: content.iter().find_map(|c| match c {
394                    message::UserContent::Text(text) => Some(text.text.clone()),
395                    _ => None,
396                }),
397                tool_calls: vec![],
398            }),
399            message::Message::Assistant { content, .. } => {
400                let mut text_content: Option<String> = None;
401                let mut tool_calls = vec![];
402
403                for c in content.iter() {
404                    match c {
405                        message::AssistantContent::Text(text) => {
406                            text_content = Some(
407                                text_content
408                                    .map(|mut existing| {
409                                        existing.push('\n');
410                                        existing.push_str(&text.text);
411                                        existing
412                                    })
413                                    .unwrap_or_else(|| text.text.clone()),
414                            );
415                        }
416                        message::AssistantContent::ToolCall(tool_call) => {
417                            tool_calls.push(tool_call.clone().into());
418                        }
419                        message::AssistantContent::Reasoning(_) => {
420                            return Err(MessageError::ConversionError(
421                                "Galadriel currently doesn't support reasoning.".into(),
422                            ));
423                        }
424                    }
425                }
426
427                Ok(Self {
428                    role: "assistant".to_string(),
429                    content: text_content,
430                    tool_calls,
431                })
432            }
433        }
434    }
435}
436
437#[derive(Clone, Debug, Deserialize, Serialize)]
438pub struct ToolDefinition {
439    pub r#type: String,
440    pub function: completion::ToolDefinition,
441}
442
443impl From<completion::ToolDefinition> for ToolDefinition {
444    fn from(tool: completion::ToolDefinition) -> Self {
445        Self {
446            r#type: "function".into(),
447            function: tool,
448        }
449    }
450}
451
/// A function name paired with its arguments as a string.
// NOTE(review): nothing in the visible part of this file references this
// type (tool calls use `openai::ToolCall`) — confirm whether it is still needed.
#[derive(Debug, Deserialize)]
pub struct Function {
    /// Name of the function.
    pub name: String,
    /// Function arguments, kept as a raw string.
    pub arguments: String,
}
457
/// A completion model bound to a Galadriel [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests for this model.
    client: Client,
    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
    pub model: String,
}
464
465impl CompletionModel {
466    pub(crate) fn create_completion_request(
467        &self,
468        completion_request: CompletionRequest,
469    ) -> Result<Value, CompletionError> {
470        // Build up the order of messages (context, chat_history, prompt)
471        let mut partial_history = vec![];
472        if let Some(docs) = completion_request.normalized_documents() {
473            partial_history.push(docs);
474        }
475        partial_history.extend(completion_request.chat_history);
476
477        // Add preamble to chat history (if available)
478        let mut full_history: Vec<Message> = match &completion_request.preamble {
479            Some(preamble) => vec![Message {
480                role: "system".to_string(),
481                content: Some(preamble.to_string()),
482                tool_calls: vec![],
483            }],
484            None => vec![],
485        };
486
487        // Convert and extend the rest of the history
488        full_history.extend(
489            partial_history
490                .into_iter()
491                .map(message::Message::try_into)
492                .collect::<Result<Vec<Message>, _>>()?,
493        );
494
495        let request = if completion_request.tools.is_empty() {
496            json!({
497                "model": self.model,
498                "messages": full_history,
499                "temperature": completion_request.temperature,
500            })
501        } else {
502            json!({
503                "model": self.model,
504                "messages": full_history,
505                "temperature": completion_request.temperature,
506                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
507                "tool_choice": "auto",
508            })
509        };
510
511        let request = if let Some(params) = completion_request.additional_params {
512            json_utils::merge(request, params)
513        } else {
514            request
515        };
516
517        Ok(request)
518    }
519}
520
521impl CompletionModel {
522    pub fn new(client: Client, model: &str) -> Self {
523        Self {
524            client,
525            model: model.to_string(),
526        }
527    }
528}
529
530impl completion::CompletionModel for CompletionModel {
531    type Response = CompletionResponse;
532    type StreamingResponse = openai::StreamingCompletionResponse;
533
534    #[cfg_attr(feature = "worker", worker::send)]
535    async fn completion(
536        &self,
537        completion_request: CompletionRequest,
538    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
539        let request = self.create_completion_request(completion_request)?;
540
541        let response = self
542            .client
543            .post("/chat/completions")
544            .json(&request)
545            .send()
546            .await?;
547
548        if response.status().is_success() {
549            let t = response.text().await?;
550            tracing::debug!(target: "rig", "Galadriel completion error: {}", t);
551
552            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
553                ApiResponse::Ok(response) => {
554                    tracing::info!(target: "rig",
555                        "Galadriel completion token usage: {:?}",
556                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
557                    );
558                    response.try_into()
559                }
560                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
561            }
562        } else {
563            Err(CompletionError::ProviderError(response.text().await?))
564        }
565    }
566
567    #[cfg_attr(feature = "worker", worker::send)]
568    async fn stream(
569        &self,
570        request: CompletionRequest,
571    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
572        let mut request = self.create_completion_request(request)?;
573
574        request = merge(
575            request,
576            json!({"stream": true, "stream_options": {"include_usage": true}}),
577        );
578
579        let builder = self.client.post("/chat/completions").json(&request);
580
581        send_compatible_streaming_request(builder).await
582    }
583}