rig/providers/
galadriel.rs

//! Galadriel API client and Rig integration
//!
//! # Example
//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::galadriel;
//!
//! let client: galadriel::Client = galadriel::Client::new("YOUR_API_KEY");
//! // to use a fine-tuned model
//! // let client: galadriel::Client = galadriel::Client::builder("YOUR_API_KEY")
//! //     .fine_tune_api_key("FINE_TUNE_API_KEY")
//! //     .build();
//!
//! let gpt4o = client.completion_model(galadriel::GPT_4O);
//! ```
13use super::openai;
14use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
15use crate::http_client::{self, HttpClientExt};
16use crate::json_utils::merge;
17use crate::message::MessageError;
18use crate::providers::openai::send_compatible_streaming_request;
19use crate::streaming::StreamingCompletionResponse;
20use crate::{
21    OneOrMany,
22    completion::{self, CompletionError, CompletionRequest},
23    impl_conversion_traits, json_utils, message,
24};
25use bytes::Bytes;
26use serde::{Deserialize, Serialize};
27use serde_json::{Value, json};
28use tracing::{Instrument, info_span};
29
30// ================================================================
31// Main Galadriel Client
32// ================================================================
/// Base URL used by [`ClientBuilder::new`] unless overridden with
/// [`ClientBuilder::base_url`].
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
34
35pub struct ClientBuilder<'a, T = reqwest::Client> {
36    api_key: &'a str,
37    fine_tune_api_key: Option<&'a str>,
38    base_url: &'a str,
39    http_client: T,
40}
41
42impl<'a, T> ClientBuilder<'a, T>
43where
44    T: Default,
45{
46    pub fn new(api_key: &'a str) -> Self {
47        Self {
48            api_key,
49            fine_tune_api_key: None,
50            base_url: GALADRIEL_API_BASE_URL,
51            http_client: Default::default(),
52        }
53    }
54}
55
56impl<'a, T> ClientBuilder<'a, T> {
57    pub fn fine_tune_api_key(mut self, fine_tune_api_key: &'a str) -> Self {
58        self.fine_tune_api_key = Some(fine_tune_api_key);
59        self
60    }
61
62    pub fn base_url(mut self, base_url: &'a str) -> Self {
63        self.base_url = base_url;
64        self
65    }
66
67    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
68        ClientBuilder {
69            api_key: self.api_key,
70            fine_tune_api_key: self.fine_tune_api_key,
71            base_url: self.base_url,
72            http_client,
73        }
74    }
75
76    pub fn build(self) -> Client<T> {
77        Client {
78            base_url: self.base_url.to_string(),
79            api_key: self.api_key.to_string(),
80            fine_tune_api_key: self.fine_tune_api_key.map(|x| x.to_string()),
81            http_client: self.http_client,
82        }
83    }
84}
85#[derive(Clone)]
86pub struct Client<T = reqwest::Client> {
87    base_url: String,
88    api_key: String,
89    fine_tune_api_key: Option<String>,
90    http_client: T,
91}
92
93impl<T> std::fmt::Debug for Client<T>
94where
95    T: std::fmt::Debug,
96{
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        f.debug_struct("Client")
99            .field("base_url", &self.base_url)
100            .field("http_client", &self.http_client)
101            .field("api_key", &"<REDACTED>")
102            .field("fine_tune_api_key", &"<REDACTED>")
103            .finish()
104    }
105}
106
107impl<T> Client<T>
108where
109    T: Default,
110{
111    /// Create a new Galadriel client builder.
112    ///
113    /// # Example
114    /// ```
115    /// use rig::providers::galadriel::{ClientBuilder, self};
116    ///
117    /// // Initialize the Galadriel client
118    /// let galadriel = Client::builder("your-galadriel-api-key")
119    ///    .build()
120    /// ```
121    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
122        ClientBuilder::new(api_key)
123    }
124
125    /// Create a new Galadriel client. For more control, use the `builder` method.
126    ///
127    /// # Panics
128    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
129    pub fn new(api_key: &str) -> Self {
130        Self::builder(api_key).build()
131    }
132}
133
134impl<T> Client<T>
135where
136    T: HttpClientExt,
137{
138    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
139        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
140
141        let mut req = http_client::Request::post(url);
142
143        if let Some(fine_tune_key) = self.fine_tune_api_key.clone() {
144            req = req.header("Fine-Tune-Authorization", fine_tune_key);
145        }
146
147        http_client::with_bearer_auth(req, &self.api_key)
148    }
149
150    async fn send<U, R>(
151        &self,
152        req: http_client::Request<U>,
153    ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
154    where
155        U: Into<Bytes> + Send,
156        R: From<Bytes> + Send + 'static,
157    {
158        self.http_client.send(req).await
159    }
160}
161
162impl Client<reqwest::Client> {
163    fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
164        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
165        let mut req = self.http_client.post(url).bearer_auth(&self.api_key);
166
167        if let Some(fine_tune_key) = self.fine_tune_api_key.clone() {
168            req = req.header("Fine-Tune-Authorization", fine_tune_key)
169        }
170
171        req
172    }
173}
174
175impl ProviderClient for Client<reqwest::Client> {
176    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
177    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
178    /// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
179    fn from_env() -> Self {
180        let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
181        let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
182        let mut builder = Self::builder(&api_key);
183        if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
184            builder = builder.fine_tune_api_key(fine_tune_api_key);
185        }
186        builder.build()
187    }
188
189    fn from_val(input: crate::client::ProviderValue) -> Self {
190        let crate::client::ProviderValue::ApiKeyWithOptionalKey(api_key, fine_tune_key) = input
191        else {
192            panic!("Incorrect provider value type")
193        };
194        let mut builder = Self::builder(&api_key);
195        if let Some(fine_tune_key) = fine_tune_key.as_deref() {
196            builder = builder.fine_tune_api_key(fine_tune_key);
197        }
198        builder.build()
199    }
200}
201
202impl CompletionClient for Client<reqwest::Client> {
203    type CompletionModel = CompletionModel<reqwest::Client>;
204
205    /// Create a completion model with the given name.
206    ///
207    /// # Example
208    /// ```
209    /// use rig::providers::galadriel::{Client, self};
210    ///
211    /// // Initialize the Galadriel client
212    /// let galadriel = Client::new("your-galadriel-api-key", None);
213    ///
214    /// let gpt4 = galadriel.completion_model(galadriel::GPT_4);
215    /// ```
216    fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
217        CompletionModel::new(self.clone(), model)
218    }
219}
220
221impl VerifyClient for Client<reqwest::Client> {
222    #[cfg_attr(feature = "worker", worker::send)]
223    async fn verify(&self) -> Result<(), VerifyError> {
224        // Could not find an API endpoint to verify the API key
225        Ok(())
226    }
227}
228
229impl_conversion_traits!(
230    AsEmbeddings,
231    AsTranscription,
232    AsImageGeneration,
233    AsAudioGeneration for Client<T>
234);
235
236#[derive(Debug, Deserialize)]
237struct ApiErrorResponse {
238    message: String,
239}
240
241#[derive(Debug, Deserialize)]
242#[serde(untagged)]
243enum ApiResponse<T> {
244    Ok(T),
245    Err(ApiErrorResponse),
246}
247
248#[derive(Clone, Debug, Deserialize, Serialize)]
249pub struct Usage {
250    pub prompt_tokens: usize,
251    pub total_tokens: usize,
252}
253
254impl std::fmt::Display for Usage {
255    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256        write!(
257            f,
258            "Prompt tokens: {} Total tokens: {}",
259            self.prompt_tokens, self.total_tokens
260        )
261    }
262}
263
264// ================================================================
265// Galadriel Completion API
266// ================================================================
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
309
310#[derive(Debug, Deserialize, Serialize)]
311pub struct CompletionResponse {
312    pub id: String,
313    pub object: String,
314    pub created: u64,
315    pub model: String,
316    pub system_fingerprint: Option<String>,
317    pub choices: Vec<Choice>,
318    pub usage: Option<Usage>,
319}
320
321impl From<ApiErrorResponse> for CompletionError {
322    fn from(err: ApiErrorResponse) -> Self {
323        CompletionError::ProviderError(err.message)
324    }
325}
326
327impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
328    type Error = CompletionError;
329
330    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
331        let Choice { message, .. } = response.choices.first().ok_or_else(|| {
332            CompletionError::ResponseError("Response contained no choices".to_owned())
333        })?;
334
335        let mut content = message
336            .content
337            .as_ref()
338            .map(|c| vec![completion::AssistantContent::text(c)])
339            .unwrap_or_default();
340
341        content.extend(message.tool_calls.iter().map(|call| {
342            completion::AssistantContent::tool_call(
343                &call.function.name,
344                &call.function.name,
345                call.function.arguments.clone(),
346            )
347        }));
348
349        let choice = OneOrMany::many(content).map_err(|_| {
350            CompletionError::ResponseError(
351                "Response contained no message or tool call (empty)".to_owned(),
352            )
353        })?;
354        let usage = response
355            .usage
356            .as_ref()
357            .map(|usage| completion::Usage {
358                input_tokens: usage.prompt_tokens as u64,
359                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
360                total_tokens: usage.total_tokens as u64,
361            })
362            .unwrap_or_default();
363
364        Ok(completion::CompletionResponse {
365            choice,
366            usage,
367            raw_response: response,
368        })
369    }
370}
371
372#[derive(Debug, Deserialize, Serialize)]
373pub struct Choice {
374    pub index: usize,
375    pub message: Message,
376    pub logprobs: Option<serde_json::Value>,
377    pub finish_reason: String,
378}
379
380#[derive(Debug, Serialize, Deserialize)]
381pub struct Message {
382    pub role: String,
383    pub content: Option<String>,
384    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
385    pub tool_calls: Vec<openai::ToolCall>,
386}
387
388impl TryFrom<Message> for message::Message {
389    type Error = message::MessageError;
390
391    fn try_from(message: Message) -> Result<Self, Self::Error> {
392        let tool_calls: Vec<message::ToolCall> = message
393            .tool_calls
394            .into_iter()
395            .map(|tool_call| tool_call.into())
396            .collect();
397
398        match message.role.as_str() {
399            "user" => Ok(Self::User {
400                content: OneOrMany::one(
401                    message
402                        .content
403                        .map(|content| message::UserContent::text(&content))
404                        .ok_or_else(|| {
405                            message::MessageError::ConversionError("Empty user message".to_string())
406                        })?,
407                ),
408            }),
409            "assistant" => Ok(Self::Assistant {
410                id: None,
411                content: OneOrMany::many(
412                    tool_calls
413                        .into_iter()
414                        .map(message::AssistantContent::ToolCall)
415                        .chain(
416                            message
417                                .content
418                                .map(|content| message::AssistantContent::text(&content))
419                                .into_iter(),
420                        ),
421                )
422                .map_err(|_| {
423                    message::MessageError::ConversionError("Empty assistant message".to_string())
424                })?,
425            }),
426            _ => Err(message::MessageError::ConversionError(format!(
427                "Unknown role: {}",
428                message.role
429            ))),
430        }
431    }
432}
433
434impl TryFrom<message::Message> for Message {
435    type Error = message::MessageError;
436
437    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
438        match message {
439            message::Message::User { content } => Ok(Self {
440                role: "user".to_string(),
441                content: content.iter().find_map(|c| match c {
442                    message::UserContent::Text(text) => Some(text.text.clone()),
443                    _ => None,
444                }),
445                tool_calls: vec![],
446            }),
447            message::Message::Assistant { content, .. } => {
448                let mut text_content: Option<String> = None;
449                let mut tool_calls = vec![];
450
451                for c in content.iter() {
452                    match c {
453                        message::AssistantContent::Text(text) => {
454                            text_content = Some(
455                                text_content
456                                    .map(|mut existing| {
457                                        existing.push('\n');
458                                        existing.push_str(&text.text);
459                                        existing
460                                    })
461                                    .unwrap_or_else(|| text.text.clone()),
462                            );
463                        }
464                        message::AssistantContent::ToolCall(tool_call) => {
465                            tool_calls.push(tool_call.clone().into());
466                        }
467                        message::AssistantContent::Reasoning(_) => {
468                            return Err(MessageError::ConversionError(
469                                "Galadriel currently doesn't support reasoning.".into(),
470                            ));
471                        }
472                    }
473                }
474
475                Ok(Self {
476                    role: "assistant".to_string(),
477                    content: text_content,
478                    tool_calls,
479                })
480            }
481        }
482    }
483}
484
485#[derive(Clone, Debug, Deserialize, Serialize)]
486pub struct ToolDefinition {
487    pub r#type: String,
488    pub function: completion::ToolDefinition,
489}
490
491impl From<completion::ToolDefinition> for ToolDefinition {
492    fn from(tool: completion::ToolDefinition) -> Self {
493        Self {
494            r#type: "function".into(),
495            function: tool,
496        }
497    }
498}
499
500#[derive(Debug, Deserialize)]
501pub struct Function {
502    pub name: String,
503    pub arguments: String,
504}
505
506#[derive(Clone)]
507pub struct CompletionModel<T = reqwest::Client> {
508    client: Client<T>,
509    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
510    pub model: String,
511}
512
513impl<T> CompletionModel<T>
514where
515    T: HttpClientExt,
516{
517    pub fn new(client: Client<T>, model: &str) -> Self {
518        Self {
519            client,
520            model: model.to_string(),
521        }
522    }
523
524    pub(crate) fn create_completion_request(
525        &self,
526        completion_request: CompletionRequest,
527    ) -> Result<Value, CompletionError> {
528        // Build up the order of messages (context, chat_history, prompt)
529        let mut partial_history = vec![];
530        if let Some(docs) = completion_request.normalized_documents() {
531            partial_history.push(docs);
532        }
533        partial_history.extend(completion_request.chat_history);
534
535        // Add preamble to chat history (if available)
536        let mut full_history: Vec<Message> = match &completion_request.preamble {
537            Some(preamble) => vec![Message {
538                role: "system".to_string(),
539                content: Some(preamble.to_string()),
540                tool_calls: vec![],
541            }],
542            None => vec![],
543        };
544
545        // Convert and extend the rest of the history
546        full_history.extend(
547            partial_history
548                .into_iter()
549                .map(message::Message::try_into)
550                .collect::<Result<Vec<Message>, _>>()?,
551        );
552
553        let tool_choice = completion_request
554            .tool_choice
555            .clone()
556            .map(crate::providers::openai::completion::ToolChoice::try_from)
557            .transpose()?;
558
559        let request = if completion_request.tools.is_empty() {
560            json!({
561                "model": self.model,
562                "messages": full_history,
563                "temperature": completion_request.temperature,
564            })
565        } else {
566            json!({
567                "model": self.model,
568                "messages": full_history,
569                "temperature": completion_request.temperature,
570                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
571                "tool_choice": tool_choice,
572            })
573        };
574
575        let request = if let Some(params) = completion_request.additional_params {
576            json_utils::merge(request, params)
577        } else {
578            request
579        };
580
581        Ok(request)
582    }
583}
584
585impl completion::CompletionModel for CompletionModel<reqwest::Client> {
586    type Response = CompletionResponse;
587    type StreamingResponse = openai::StreamingCompletionResponse;
588
589    #[cfg_attr(feature = "worker", worker::send)]
590    async fn completion(
591        &self,
592        completion_request: CompletionRequest,
593    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
594        let preamble = completion_request.preamble.clone();
595        let request = self.create_completion_request(completion_request)?;
596        let body = serde_json::to_vec(&request)?;
597
598        let req = self
599            .client
600            .post("/chat/completions")?
601            .header("Content-Type", "application/json")
602            .body(body)
603            .map_err(http_client::Error::from)?;
604
605        let span = if tracing::Span::current().is_disabled() {
606            info_span!(
607                target: "rig::completions",
608                "chat",
609                gen_ai.operation.name = "chat",
610                gen_ai.provider.name = "galadriel",
611                gen_ai.request.model = self.model,
612                gen_ai.system_instructions = preamble,
613                gen_ai.response.id = tracing::field::Empty,
614                gen_ai.response.model = tracing::field::Empty,
615                gen_ai.usage.output_tokens = tracing::field::Empty,
616                gen_ai.usage.input_tokens = tracing::field::Empty,
617                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
618                gen_ai.output.messages = tracing::field::Empty,
619            )
620        } else {
621            tracing::Span::current()
622        };
623
624        async move {
625            let response = self.client.send(req).await?;
626
627            if response.status().is_success() {
628                let t = http_client::text(response).await?;
629                tracing::debug!(target: "rig::completions", "Galadriel completion response: {t}");
630
631                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
632                    ApiResponse::Ok(response) => {
633                        let span = tracing::Span::current();
634                        span.record("gen_ai.response.id", response.id.clone());
635                        span.record("gen_ai.response.model_name", response.model.clone());
636                        span.record(
637                            "gen_ai.output.messages",
638                            serde_json::to_string(&response.choices).unwrap(),
639                        );
640                        if let Some(ref usage) = response.usage {
641                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
642                            span.record(
643                                "gen_ai.usage.output_tokens",
644                                usage.total_tokens - usage.prompt_tokens,
645                            );
646                        }
647                        response.try_into()
648                    }
649                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
650                }
651            } else {
652                let text = http_client::text(response).await?;
653
654                Err(CompletionError::ProviderError(text))
655            }
656        }
657        .instrument(span)
658        .await
659    }
660
661    #[cfg_attr(feature = "worker", worker::send)]
662    async fn stream(
663        &self,
664        request: CompletionRequest,
665    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
666        let preamble = request.preamble.clone();
667        let mut request = self.create_completion_request(request)?;
668
669        request = merge(
670            request,
671            json!({"stream": true, "stream_options": {"include_usage": true}}),
672        );
673
674        let builder = self
675            .client
676            .reqwest_post("/chat/completions")
677            .header("Content-Type", "application/json")
678            .json(&request);
679
680        let span = if tracing::Span::current().is_disabled() {
681            info_span!(
682                target: "rig::completions",
683                "chat_streaming",
684                gen_ai.operation.name = "chat_streaming",
685                gen_ai.provider.name = "galadriel",
686                gen_ai.request.model = self.model,
687                gen_ai.system_instructions = preamble,
688                gen_ai.response.id = tracing::field::Empty,
689                gen_ai.response.model = tracing::field::Empty,
690                gen_ai.usage.output_tokens = tracing::field::Empty,
691                gen_ai.usage.input_tokens = tracing::field::Empty,
692                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
693                gen_ai.output.messages = tracing::field::Empty,
694            )
695        } else {
696            tracing::Span::current()
697        };
698
699        send_compatible_streaming_request(builder)
700            .instrument(span)
701            .await
702    }
703}