rig/providers/
galadriel.rs

//! Galadriel API client and Rig integration
//!
//! # Example
//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::galadriel;
//!
//! let client = galadriel::Client::new("YOUR_API_KEY");
//! // to use a fine-tuned model
//! // let client = galadriel::Client::builder("YOUR_API_KEY")
//! //     .fine_tune_api_key("FINE_TUNE_API_KEY")
//! //     .build();
//!
//! let gpt4o = client.completion_model(galadriel::GPT_4O);
//! ```
13use super::openai;
14use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
15use crate::http_client::{self, HttpClientExt};
16use crate::json_utils::merge;
17use crate::message::MessageError;
18use crate::providers::openai::send_compatible_streaming_request;
19use crate::streaming::StreamingCompletionResponse;
20use crate::{
21    OneOrMany,
22    completion::{self, CompletionError, CompletionRequest},
23    impl_conversion_traits, json_utils, message,
24};
25use serde::{Deserialize, Serialize};
26use serde_json::{Value, json};
27use tracing::{Instrument, info_span};
28
29// ================================================================
30// Main Galadriel Client
31// ================================================================
32const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
33
34pub struct ClientBuilder<'a, T = reqwest::Client> {
35    api_key: &'a str,
36    fine_tune_api_key: Option<&'a str>,
37    base_url: &'a str,
38    http_client: T,
39}
40
41impl<'a, T> ClientBuilder<'a, T>
42where
43    T: Default,
44{
45    pub fn new(api_key: &'a str) -> Self {
46        Self {
47            api_key,
48            fine_tune_api_key: None,
49            base_url: GALADRIEL_API_BASE_URL,
50            http_client: Default::default(),
51        }
52    }
53}
54
55impl<'a, T> ClientBuilder<'a, T> {
56    pub fn fine_tune_api_key(mut self, fine_tune_api_key: &'a str) -> Self {
57        self.fine_tune_api_key = Some(fine_tune_api_key);
58        self
59    }
60
61    pub fn base_url(mut self, base_url: &'a str) -> Self {
62        self.base_url = base_url;
63        self
64    }
65
66    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
67        ClientBuilder {
68            api_key: self.api_key,
69            fine_tune_api_key: self.fine_tune_api_key,
70            base_url: self.base_url,
71            http_client,
72        }
73    }
74
75    pub fn build(self) -> Client<T> {
76        Client {
77            base_url: self.base_url.to_string(),
78            api_key: self.api_key.to_string(),
79            fine_tune_api_key: self.fine_tune_api_key.map(|x| x.to_string()),
80            http_client: self.http_client,
81        }
82    }
83}
84#[derive(Clone)]
85pub struct Client<T = reqwest::Client> {
86    base_url: String,
87    api_key: String,
88    fine_tune_api_key: Option<String>,
89    http_client: T,
90}
91
92impl<T> std::fmt::Debug for Client<T>
93where
94    T: std::fmt::Debug,
95{
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        f.debug_struct("Client")
98            .field("base_url", &self.base_url)
99            .field("http_client", &self.http_client)
100            .field("api_key", &"<REDACTED>")
101            .field("fine_tune_api_key", &"<REDACTED>")
102            .finish()
103    }
104}
105
106impl<T> Client<T>
107where
108    T: HttpClientExt,
109{
110    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
111        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
112
113        let mut req = http_client::Request::post(url);
114
115        if let Some(fine_tune_key) = self.fine_tune_api_key.clone() {
116            req = req.header("Fine-Tune-Authorization", fine_tune_key);
117        }
118
119        http_client::with_bearer_auth(req, &self.api_key)
120    }
121}
122
123impl Client<reqwest::Client> {
124    pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
125        ClientBuilder::new(api_key)
126    }
127
128    pub fn new(api_key: &str) -> Self {
129        ClientBuilder::new(api_key).build()
130    }
131
132    pub fn from_env() -> Self {
133        <Self as ProviderClient>::from_env()
134    }
135}
136
137impl<T> ProviderClient for Client<T>
138where
139    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
140{
141    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
142    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
143    /// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
144    fn from_env() -> Self {
145        let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
146        let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
147        let mut builder = ClientBuilder::<T>::new(&api_key);
148        if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
149            builder = builder.fine_tune_api_key(fine_tune_api_key);
150        }
151        builder.build()
152    }
153
154    fn from_val(input: crate::client::ProviderValue) -> Self {
155        let crate::client::ProviderValue::ApiKeyWithOptionalKey(api_key, fine_tune_key) = input
156        else {
157            panic!("Incorrect provider value type")
158        };
159        let mut builder = ClientBuilder::<T>::new(&api_key);
160        if let Some(fine_tune_key) = fine_tune_key.as_deref() {
161            builder = builder.fine_tune_api_key(fine_tune_key);
162        }
163        builder.build()
164    }
165}
166
167impl<T> CompletionClient for Client<T>
168where
169    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
170{
171    type CompletionModel = CompletionModel<T>;
172
173    /// Create a completion model with the given name.
174    ///
175    /// # Example
176    /// ```
177    /// use rig::providers::galadriel::{Client, self};
178    ///
179    /// // Initialize the Galadriel client
180    /// let galadriel = Client::new("your-galadriel-api-key", None);
181    ///
182    /// let gpt4 = galadriel.completion_model(galadriel::GPT_4);
183    /// ```
184    fn completion_model(&self, model: &str) -> Self::CompletionModel {
185        CompletionModel::new(self.clone(), model)
186    }
187}
188
189impl<T> VerifyClient for Client<T>
190where
191    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
192{
193    #[cfg_attr(feature = "worker", worker::send)]
194    async fn verify(&self) -> Result<(), VerifyError> {
195        // Could not find an API endpoint to verify the API key
196        Ok(())
197    }
198}
199
200impl_conversion_traits!(
201    AsEmbeddings,
202    AsTranscription,
203    AsImageGeneration,
204    AsAudioGeneration for Client<T>
205);
206
207#[derive(Debug, Deserialize)]
208struct ApiErrorResponse {
209    message: String,
210}
211
212#[derive(Debug, Deserialize)]
213#[serde(untagged)]
214enum ApiResponse<T> {
215    Ok(T),
216    Err(ApiErrorResponse),
217}
218
219#[derive(Clone, Debug, Deserialize, Serialize)]
220pub struct Usage {
221    pub prompt_tokens: usize,
222    pub total_tokens: usize,
223}
224
225impl std::fmt::Display for Usage {
226    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227        write!(
228            f,
229            "Prompt tokens: {} Total tokens: {}",
230            self.prompt_tokens, self.total_tokens
231        )
232    }
233}
234
235// ================================================================
236// Galadriel Completion API
237// ================================================================
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
280
281#[derive(Debug, Deserialize, Serialize)]
282pub struct CompletionResponse {
283    pub id: String,
284    pub object: String,
285    pub created: u64,
286    pub model: String,
287    pub system_fingerprint: Option<String>,
288    pub choices: Vec<Choice>,
289    pub usage: Option<Usage>,
290}
291
292impl From<ApiErrorResponse> for CompletionError {
293    fn from(err: ApiErrorResponse) -> Self {
294        CompletionError::ProviderError(err.message)
295    }
296}
297
298impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
299    type Error = CompletionError;
300
301    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
302        let Choice { message, .. } = response.choices.first().ok_or_else(|| {
303            CompletionError::ResponseError("Response contained no choices".to_owned())
304        })?;
305
306        let mut content = message
307            .content
308            .as_ref()
309            .map(|c| vec![completion::AssistantContent::text(c)])
310            .unwrap_or_default();
311
312        content.extend(message.tool_calls.iter().map(|call| {
313            completion::AssistantContent::tool_call(
314                &call.function.name,
315                &call.function.name,
316                call.function.arguments.clone(),
317            )
318        }));
319
320        let choice = OneOrMany::many(content).map_err(|_| {
321            CompletionError::ResponseError(
322                "Response contained no message or tool call (empty)".to_owned(),
323            )
324        })?;
325        let usage = response
326            .usage
327            .as_ref()
328            .map(|usage| completion::Usage {
329                input_tokens: usage.prompt_tokens as u64,
330                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
331                total_tokens: usage.total_tokens as u64,
332            })
333            .unwrap_or_default();
334
335        Ok(completion::CompletionResponse {
336            choice,
337            usage,
338            raw_response: response,
339        })
340    }
341}
342
343#[derive(Debug, Deserialize, Serialize)]
344pub struct Choice {
345    pub index: usize,
346    pub message: Message,
347    pub logprobs: Option<serde_json::Value>,
348    pub finish_reason: String,
349}
350
351#[derive(Debug, Serialize, Deserialize)]
352pub struct Message {
353    pub role: String,
354    pub content: Option<String>,
355    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
356    pub tool_calls: Vec<openai::ToolCall>,
357}
358
359impl TryFrom<Message> for message::Message {
360    type Error = message::MessageError;
361
362    fn try_from(message: Message) -> Result<Self, Self::Error> {
363        let tool_calls: Vec<message::ToolCall> = message
364            .tool_calls
365            .into_iter()
366            .map(|tool_call| tool_call.into())
367            .collect();
368
369        match message.role.as_str() {
370            "user" => Ok(Self::User {
371                content: OneOrMany::one(
372                    message
373                        .content
374                        .map(|content| message::UserContent::text(&content))
375                        .ok_or_else(|| {
376                            message::MessageError::ConversionError("Empty user message".to_string())
377                        })?,
378                ),
379            }),
380            "assistant" => Ok(Self::Assistant {
381                id: None,
382                content: OneOrMany::many(
383                    tool_calls
384                        .into_iter()
385                        .map(message::AssistantContent::ToolCall)
386                        .chain(
387                            message
388                                .content
389                                .map(|content| message::AssistantContent::text(&content))
390                                .into_iter(),
391                        ),
392                )
393                .map_err(|_| {
394                    message::MessageError::ConversionError("Empty assistant message".to_string())
395                })?,
396            }),
397            _ => Err(message::MessageError::ConversionError(format!(
398                "Unknown role: {}",
399                message.role
400            ))),
401        }
402    }
403}
404
405impl TryFrom<message::Message> for Message {
406    type Error = message::MessageError;
407
408    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
409        match message {
410            message::Message::User { content } => Ok(Self {
411                role: "user".to_string(),
412                content: content.iter().find_map(|c| match c {
413                    message::UserContent::Text(text) => Some(text.text.clone()),
414                    _ => None,
415                }),
416                tool_calls: vec![],
417            }),
418            message::Message::Assistant { content, .. } => {
419                let mut text_content: Option<String> = None;
420                let mut tool_calls = vec![];
421
422                for c in content.iter() {
423                    match c {
424                        message::AssistantContent::Text(text) => {
425                            text_content = Some(
426                                text_content
427                                    .map(|mut existing| {
428                                        existing.push('\n');
429                                        existing.push_str(&text.text);
430                                        existing
431                                    })
432                                    .unwrap_or_else(|| text.text.clone()),
433                            );
434                        }
435                        message::AssistantContent::ToolCall(tool_call) => {
436                            tool_calls.push(tool_call.clone().into());
437                        }
438                        message::AssistantContent::Reasoning(_) => {
439                            return Err(MessageError::ConversionError(
440                                "Galadriel currently doesn't support reasoning.".into(),
441                            ));
442                        }
443                    }
444                }
445
446                Ok(Self {
447                    role: "assistant".to_string(),
448                    content: text_content,
449                    tool_calls,
450                })
451            }
452        }
453    }
454}
455
456#[derive(Clone, Debug, Deserialize, Serialize)]
457pub struct ToolDefinition {
458    pub r#type: String,
459    pub function: completion::ToolDefinition,
460}
461
462impl From<completion::ToolDefinition> for ToolDefinition {
463    fn from(tool: completion::ToolDefinition) -> Self {
464        Self {
465            r#type: "function".into(),
466            function: tool,
467        }
468    }
469}
470
471#[derive(Debug, Deserialize)]
472pub struct Function {
473    pub name: String,
474    pub arguments: String,
475}
476
477#[derive(Clone)]
478pub struct CompletionModel<T = reqwest::Client> {
479    client: Client<T>,
480    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
481    pub model: String,
482}
483
484impl<T> CompletionModel<T>
485where
486    T: HttpClientExt,
487{
488    pub fn new(client: Client<T>, model: &str) -> Self {
489        Self {
490            client,
491            model: model.to_string(),
492        }
493    }
494
495    pub(crate) fn create_completion_request(
496        &self,
497        completion_request: CompletionRequest,
498    ) -> Result<Value, CompletionError> {
499        // Build up the order of messages (context, chat_history, prompt)
500        let mut partial_history = vec![];
501        if let Some(docs) = completion_request.normalized_documents() {
502            partial_history.push(docs);
503        }
504        partial_history.extend(completion_request.chat_history);
505
506        // Add preamble to chat history (if available)
507        let mut full_history: Vec<Message> = match &completion_request.preamble {
508            Some(preamble) => vec![Message {
509                role: "system".to_string(),
510                content: Some(preamble.to_string()),
511                tool_calls: vec![],
512            }],
513            None => vec![],
514        };
515
516        // Convert and extend the rest of the history
517        full_history.extend(
518            partial_history
519                .into_iter()
520                .map(message::Message::try_into)
521                .collect::<Result<Vec<Message>, _>>()?,
522        );
523
524        let tool_choice = completion_request
525            .tool_choice
526            .clone()
527            .map(crate::providers::openai::completion::ToolChoice::try_from)
528            .transpose()?;
529
530        let request = if completion_request.tools.is_empty() {
531            json!({
532                "model": self.model,
533                "messages": full_history,
534                "temperature": completion_request.temperature,
535            })
536        } else {
537            json!({
538                "model": self.model,
539                "messages": full_history,
540                "temperature": completion_request.temperature,
541                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
542                "tool_choice": tool_choice,
543            })
544        };
545
546        let request = if let Some(params) = completion_request.additional_params {
547            json_utils::merge(request, params)
548        } else {
549            request
550        };
551
552        Ok(request)
553    }
554}
555
556impl<T> completion::CompletionModel for CompletionModel<T>
557where
558    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
559{
560    type Response = CompletionResponse;
561    type StreamingResponse = openai::StreamingCompletionResponse;
562
563    #[cfg_attr(feature = "worker", worker::send)]
564    async fn completion(
565        &self,
566        completion_request: CompletionRequest,
567    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
568        let preamble = completion_request.preamble.clone();
569        let request = self.create_completion_request(completion_request)?;
570        let body = serde_json::to_vec(&request)?;
571
572        let req = self
573            .client
574            .post("/chat/completions")?
575            .header("Content-Type", "application/json")
576            .body(body)
577            .map_err(http_client::Error::from)?;
578
579        let span = if tracing::Span::current().is_disabled() {
580            info_span!(
581                target: "rig::completions",
582                "chat",
583                gen_ai.operation.name = "chat",
584                gen_ai.provider.name = "galadriel",
585                gen_ai.request.model = self.model,
586                gen_ai.system_instructions = preamble,
587                gen_ai.response.id = tracing::field::Empty,
588                gen_ai.response.model = tracing::field::Empty,
589                gen_ai.usage.output_tokens = tracing::field::Empty,
590                gen_ai.usage.input_tokens = tracing::field::Empty,
591                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
592                gen_ai.output.messages = tracing::field::Empty,
593            )
594        } else {
595            tracing::Span::current()
596        };
597
598        async move {
599            let response = self.client.http_client.send(req).await?;
600
601            if response.status().is_success() {
602                let t = http_client::text(response).await?;
603                tracing::debug!(target: "rig::completions", "Galadriel completion response: {t}");
604
605                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
606                    ApiResponse::Ok(response) => {
607                        let span = tracing::Span::current();
608                        span.record("gen_ai.response.id", response.id.clone());
609                        span.record("gen_ai.response.model_name", response.model.clone());
610                        span.record(
611                            "gen_ai.output.messages",
612                            serde_json::to_string(&response.choices).unwrap(),
613                        );
614                        if let Some(ref usage) = response.usage {
615                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
616                            span.record(
617                                "gen_ai.usage.output_tokens",
618                                usage.total_tokens - usage.prompt_tokens,
619                            );
620                        }
621                        response.try_into()
622                    }
623                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
624                }
625            } else {
626                let text = http_client::text(response).await?;
627
628                Err(CompletionError::ProviderError(text))
629            }
630        }
631        .instrument(span)
632        .await
633    }
634
635    #[cfg_attr(feature = "worker", worker::send)]
636    async fn stream(
637        &self,
638        request: CompletionRequest,
639    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
640        let preamble = request.preamble.clone();
641        let mut request = self.create_completion_request(request)?;
642
643        request = merge(
644            request,
645            json!({"stream": true, "stream_options": {"include_usage": true}}),
646        );
647
648        let body = serde_json::to_vec(&request)?;
649
650        let req = self
651            .client
652            .post("/chat/completions")?
653            .header("Content-Type", "application/json")
654            .body(body)
655            .map_err(http_client::Error::from)?;
656
657        let span = if tracing::Span::current().is_disabled() {
658            info_span!(
659                target: "rig::completions",
660                "chat_streaming",
661                gen_ai.operation.name = "chat_streaming",
662                gen_ai.provider.name = "galadriel",
663                gen_ai.request.model = self.model,
664                gen_ai.system_instructions = preamble,
665                gen_ai.response.id = tracing::field::Empty,
666                gen_ai.response.model = tracing::field::Empty,
667                gen_ai.usage.output_tokens = tracing::field::Empty,
668                gen_ai.usage.input_tokens = tracing::field::Empty,
669                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
670                gen_ai.output.messages = tracing::field::Empty,
671            )
672        } else {
673            tracing::Span::current()
674        };
675
676        send_compatible_streaming_request(self.client.http_client.clone(), req)
677            .instrument(span)
678            .await
679    }
680}