Skip to main content

rig_core/providers/
galadriel.rs

1//! Galadriel API client and Rig integration
2//!
3//! # Example
4//! ```no_run
5//! use rig_core::{client::CompletionClient, providers::galadriel};
6//!
7//! # fn run() -> Result<(), Box<dyn std::error::Error>> {
8//! let client = galadriel::Client::new("YOUR_API_KEY")?;
9//! // to use a fine-tuned model
10//! // let client = galadriel::Client::builder()
11//! //     .api_key("YOUR_API_KEY")
12//! //     .fine_tune_api_key("FINE_TUNE_API_KEY")
13//! //     .build()?;
14//!
15//! let gpt4o = client.completion_model(galadriel::GPT_4O);
16//! # Ok(())
17//! # }
18//! ```
19use super::openai;
20use crate::client::{
21    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22    ProviderClient,
23};
24use crate::http_client::{self, HttpClientExt};
25use crate::message::MessageError;
26use crate::providers::openai::send_compatible_streaming_request;
27use crate::streaming::StreamingCompletionResponse;
28use crate::{
29    OneOrMany,
30    completion::{self, CompletionError, CompletionRequest},
31    json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34use tracing::{Instrument, enabled, info_span};
35
36// ================================================================
37// Main Galadriel Client
38// ================================================================
39const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
40
/// Provider-specific state stored on a Galadriel client.
///
/// Currently this is only the optional fine-tune API key that was supplied
/// through the client builder.
#[derive(Clone, Debug, Default)]
pub struct GaladrielExt {
    /// Fine-tune API key, or `None` when none was configured.
    fine_tune_api_key: Option<String>,
}
45
/// Builder-side counterpart of `GaladrielExt`.
///
/// Collects the optional fine-tune API key before the client is built.
#[derive(Clone, Debug, Default)]
pub struct GaladrielBuilder {
    /// Fine-tune API key, or `None` when none has been set on the builder.
    fine_tune_api_key: Option<String>,
}
50
/// API-key type used by the Galadriel client builder; an alias for `BearerAuth`.
type GaladrielApiKey = BearerAuth;
52
/// Registers `GaladrielExt` as a provider that is configured through `GaladrielBuilder`.
impl Provider for GaladrielExt {
    type Builder = GaladrielBuilder;

    /// Left empty: there is currently no way to verify a Galadriel API key
    /// without consuming tokens.
    const VERIFY_PATH: &'static str = "";
}
59
/// Capability table for Galadriel: only completions are marked `Capable`
/// (through `CompletionModel<H>`); every other capability is `Nothing`.
impl<H> Capabilities<H> for GaladrielExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    // The two capabilities below only exist when the matching crate feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
70
71impl DebugExt for GaladrielExt {
72    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
73        std::iter::once((
74            "fine_tune_api_key",
75            (&self.fine_tune_api_key as &dyn std::fmt::Debug),
76        ))
77    }
78}
79
80impl ProviderBuilder for GaladrielBuilder {
81    type Extension<H>
82        = GaladrielExt
83    where
84        H: HttpClientExt;
85    type ApiKey = GaladrielApiKey;
86
87    const BASE_URL: &'static str = GALADRIEL_API_BASE_URL;
88
89    fn build<H>(
90        builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
91    ) -> http_client::Result<Self::Extension<H>>
92    where
93        H: HttpClientExt,
94    {
95        let GaladrielBuilder { fine_tune_api_key } = builder.ext().clone();
96
97        Ok(GaladrielExt { fine_tune_api_key })
98    }
99}
100
/// Galadriel client, generic over the HTTP backend `H` (defaults to `reqwest::Client`).
pub type Client<H = reqwest::Client> = client::Client<GaladrielExt, H>;
/// Builder for a Galadriel [`Client`]; `H` defaults to the `Missing` marker type.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<GaladrielBuilder, GaladrielApiKey, H>;
104
105impl<T> ClientBuilder<T> {
106    pub fn fine_tune_api_key<S>(mut self, fine_tune_api_key: S) -> Self
107    where
108        S: AsRef<str>,
109    {
110        *self.ext_mut() = GaladrielBuilder {
111            fine_tune_api_key: Some(fine_tune_api_key.as_ref().into()),
112        };
113
114        self
115    }
116}
117
118impl ProviderClient for Client {
119    type Input = (String, Option<String>);
120    type Error = crate::client::ProviderClientError;
121
122    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
123    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
124    fn from_env() -> Result<Self, Self::Error> {
125        let api_key = crate::client::required_env_var("GALADRIEL_API_KEY")?;
126        let fine_tune_api_key = crate::client::optional_env_var("GALADRIEL_FINE_TUNE_API_KEY")?;
127
128        let mut builder = Self::builder().api_key(api_key);
129
130        if let Some(fine_tune_api_key) = fine_tune_api_key {
131            builder = builder.fine_tune_api_key(fine_tune_api_key);
132        }
133
134        builder.build().map_err(Into::into)
135    }
136
137    fn from_val((api_key, fine_tune_api_key): Self::Input) -> Result<Self, Self::Error> {
138        let mut builder = Self::builder().api_key(api_key);
139
140        if let Some(fine_tune_key) = fine_tune_api_key {
141            builder = builder.fine_tune_api_key(fine_tune_key)
142        }
143
144        builder.build().map_err(Into::into)
145    }
146}
147
/// Error payload returned by the API; only the `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text as reported by the provider.
    message: String,
}
152
/// Response body that is either a successful payload `T` or an error payload.
///
/// The enum is untagged, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error object.
    Err(ApiErrorResponse),
}
159
/// Token usage as reported in a completion response.
///
/// Only prompt and total counts are carried; output tokens are derived by the
/// code that consumes this struct.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total number of tokens reported for the request.
    pub total_tokens: usize,
}
165
166impl std::fmt::Display for Usage {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        write!(
169            f,
170            "Prompt tokens: {} Total tokens: {}",
171            self.prompt_tokens, self.total_tokens
172        )
173    }
174}
175
176// ================================================================
177// Galadriel Completion API
178// ================================================================
179
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
222
/// Raw body of a successful chat-completion response.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Response identifier.
    pub id: String,
    /// Object type string as reported by the API.
    pub object: String,
    /// Creation timestamp as reported by the API (unit not visible here).
    pub created: u64,
    /// Name of the model that produced the response.
    pub model: String,
    /// Optional system fingerprint.
    pub system_fingerprint: Option<String>,
    /// Returned choices; only the first one is used when converting to Rig's response type.
    pub choices: Vec<Choice>,
    /// Token usage, when the API reports it.
    pub usage: Option<Usage>,
}
233
234impl From<ApiErrorResponse> for CompletionError {
235    fn from(err: ApiErrorResponse) -> Self {
236        CompletionError::ProviderError(err.message)
237    }
238}
239
240impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
241    type Error = CompletionError;
242
243    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
244        let Choice { message, .. } = response.choices.first().ok_or_else(|| {
245            CompletionError::ResponseError("Response contained no choices".to_owned())
246        })?;
247
248        let mut content = message
249            .content
250            .as_ref()
251            .map(|c| vec![completion::AssistantContent::text(c)])
252            .unwrap_or_default();
253
254        content.extend(message.tool_calls.iter().map(|call| {
255            completion::AssistantContent::tool_call(
256                &call.function.name,
257                &call.function.name,
258                call.function.arguments.clone(),
259            )
260        }));
261
262        let choice = OneOrMany::many(content).map_err(|_| {
263            CompletionError::ResponseError(
264                "Response contained no message or tool call (empty)".to_owned(),
265            )
266        })?;
267        let usage = response
268            .usage
269            .as_ref()
270            .map(|usage| completion::Usage {
271                input_tokens: usage.prompt_tokens as u64,
272                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
273                total_tokens: usage.total_tokens as u64,
274                cached_input_tokens: 0,
275                cache_creation_input_tokens: 0,
276                tool_use_prompt_tokens: 0,
277                reasoning_tokens: 0,
278            })
279            .unwrap_or_default();
280
281        Ok(completion::CompletionResponse {
282            choice,
283            usage,
284            raw_response: response,
285            message_id: None,
286        })
287    }
288}
289
/// One entry of the `choices` array in a completion response.
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Position of this choice in the response.
    pub index: usize,
    /// The message generated for this choice.
    pub message: Message,
    /// Log-probability data, kept as raw JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Reason string for why generation stopped, as reported by the API.
    pub finish_reason: String,
}
297
/// Chat message in Galadriel's wire format.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Role string; this file produces and understands `"system"`, `"user"` and `"assistant"`.
    pub role: String,
    /// Text content, if any.
    pub content: Option<String>,
    /// Tool calls attached to the message. A missing field defaults to empty;
    /// presumably `null_or_vec` also maps JSON `null` to an empty vector — confirm.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<openai::ToolCall>,
}
305
306impl Message {
307    fn system(preamble: &str) -> Self {
308        Self {
309            role: "system".to_string(),
310            content: Some(preamble.to_string()),
311            tool_calls: Vec::new(),
312        }
313    }
314}
315
316impl TryFrom<Message> for message::Message {
317    type Error = message::MessageError;
318
319    fn try_from(message: Message) -> Result<Self, Self::Error> {
320        let tool_calls: Vec<message::ToolCall> = message
321            .tool_calls
322            .into_iter()
323            .map(|tool_call| tool_call.into())
324            .collect();
325
326        match message.role.as_str() {
327            "user" => Ok(Self::User {
328                content: OneOrMany::one(
329                    message
330                        .content
331                        .map(|content| message::UserContent::text(&content))
332                        .ok_or_else(|| {
333                            message::MessageError::ConversionError("Empty user message".to_string())
334                        })?,
335                ),
336            }),
337            "assistant" => Ok(Self::Assistant {
338                id: None,
339                content: OneOrMany::many(
340                    tool_calls
341                        .into_iter()
342                        .map(message::AssistantContent::ToolCall)
343                        .chain(
344                            message
345                                .content
346                                .map(|content| message::AssistantContent::text(&content))
347                                .into_iter(),
348                        ),
349                )
350                .map_err(|_| {
351                    message::MessageError::ConversionError("Empty assistant message".to_string())
352                })?,
353            }),
354            _ => Err(message::MessageError::ConversionError(format!(
355                "Unknown role: {}",
356                message.role
357            ))),
358        }
359    }
360}
361
362impl TryFrom<message::Message> for Message {
363    type Error = message::MessageError;
364
365    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
366        match message {
367            message::Message::System { content } => Ok(Self {
368                role: "system".to_string(),
369                content: Some(content),
370                tool_calls: vec![],
371            }),
372            message::Message::User { content } => Ok(Self {
373                role: "user".to_string(),
374                content: content.iter().find_map(|c| match c {
375                    message::UserContent::Text(text) => Some(text.text.clone()),
376                    _ => None,
377                }),
378                tool_calls: vec![],
379            }),
380            message::Message::Assistant { content, .. } => {
381                let mut text_content: Option<String> = None;
382                let mut tool_calls = vec![];
383
384                for c in content.iter() {
385                    match c {
386                        message::AssistantContent::Text(text) => {
387                            text_content = Some(
388                                text_content
389                                    .map(|mut existing| {
390                                        existing.push('\n');
391                                        existing.push_str(&text.text);
392                                        existing
393                                    })
394                                    .unwrap_or_else(|| text.text.clone()),
395                            );
396                        }
397                        message::AssistantContent::ToolCall(tool_call) => {
398                            tool_calls.push(tool_call.clone().into());
399                        }
400                        message::AssistantContent::Reasoning(_) => {
401                            return Err(MessageError::ConversionError(
402                                "Galadriel currently doesn't support reasoning.".into(),
403                            ));
404                        }
405                        message::AssistantContent::Image(_) => {
406                            return Err(MessageError::ConversionError(
407                                "Galadriel currently doesn't support images.".into(),
408                            ));
409                        }
410                    }
411                }
412
413                Ok(Self {
414                    role: "assistant".to_string(),
415                    content: text_content,
416                    tool_calls,
417                })
418            }
419        }
420    }
421}
422
/// Tool definition in the wire format: a type tag plus Rig's own tool definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool type tag; the `From` conversion in this file always sets it to `"function"`.
    pub r#type: String,
    /// The wrapped tool definition.
    pub function: completion::ToolDefinition,
}
428
429impl From<completion::ToolDefinition> for ToolDefinition {
430    fn from(tool: completion::ToolDefinition) -> Self {
431        Self {
432            r#type: "function".into(),
433            function: tool,
434        }
435    }
436}
437
/// Name/arguments pair describing a function call.
///
/// NOTE(review): not referenced anywhere else in the visible part of this
/// file — confirm whether it is still needed.
#[derive(Debug, Deserialize)]
pub struct Function {
    /// Function name.
    pub name: String,
    /// Arguments, kept as a string.
    pub arguments: String,
}
443
/// Request body sent to the `/chat/completions` endpoint.
///
/// NOTE(review): `tools` is skipped on serialization when empty but has no
/// `#[serde(default)]`, so deserializing such a body would presumably fail on
/// the missing field — confirm whether `Deserialize` is actually relied on.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GaladrielCompletionRequest {
    /// Model name to run the completion with.
    model: String,
    /// Conversation, in the order it is sent.
    pub messages: Vec<Message>,
    /// Sampling temperature; omitted from the body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tool definitions; omitted from the body when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Tool-choice setting; omitted from the body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    /// Extra parameters, flattened into the top level of the body when present.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
457
458impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest {
459    type Error = CompletionError;
460
461    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
462        if req.output_schema.is_some() {
463            tracing::warn!("Structured outputs currently not supported for Galadriel");
464        }
465        let model = req.model.clone().unwrap_or_else(|| model.to_string());
466        // Build up the order of messages (context, chat_history, prompt)
467        let mut partial_history = vec![];
468        if let Some(docs) = req.normalized_documents() {
469            partial_history.push(docs);
470        }
471        partial_history.extend(req.chat_history);
472
473        // Add preamble to chat history (if available)
474        let mut full_history: Vec<Message> = match &req.preamble {
475            Some(preamble) => vec![Message::system(preamble)],
476            None => vec![],
477        };
478
479        // Convert and extend the rest of the history
480        full_history.extend(
481            partial_history
482                .into_iter()
483                .map(message::Message::try_into)
484                .collect::<Result<Vec<Message>, _>>()?,
485        );
486
487        let tool_choice = req
488            .tool_choice
489            .clone()
490            .map(crate::providers::openai::completion::ToolChoice::try_from)
491            .transpose()?;
492
493        Ok(Self {
494            model: model.to_string(),
495            messages: full_history,
496            temperature: req.temperature,
497            tools: req
498                .tools
499                .clone()
500                .into_iter()
501                .map(ToolDefinition::from)
502                .collect::<Vec<_>>(),
503            tool_choice,
504            additional_params: req.additional_params,
505        })
506    }
507}
508
/// Handle to a single Galadriel completion model, bound to a client.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests.
    client: Client<T>,
    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
    pub model: String,
}
515
516impl<T> CompletionModel<T>
517where
518    T: HttpClientExt,
519{
520    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
521        Self {
522            client,
523            model: model.into(),
524        }
525    }
526
527    pub fn with_model(client: Client<T>, model: &str) -> Self {
528        Self {
529            client,
530            model: model.into(),
531        }
532    }
533}
534
535impl<T> completion::CompletionModel for CompletionModel<T>
536where
537    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
538{
539    type Response = CompletionResponse;
540    type StreamingResponse = openai::StreamingCompletionResponse;
541
542    type Client = Client<T>;
543
544    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
545        Self::new(client.clone(), model.into())
546    }
547
548    async fn completion(
549        &self,
550        completion_request: CompletionRequest,
551    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
552        let span = if tracing::Span::current().is_disabled() {
553            info_span!(
554                target: "rig::completions",
555                "chat",
556                gen_ai.operation.name = "chat",
557                gen_ai.provider.name = "galadriel",
558                gen_ai.request.model = self.model,
559                gen_ai.system_instructions = tracing::field::Empty,
560                gen_ai.response.id = tracing::field::Empty,
561                gen_ai.response.model = tracing::field::Empty,
562                gen_ai.usage.output_tokens = tracing::field::Empty,
563                gen_ai.usage.input_tokens = tracing::field::Empty,
564                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
565            )
566        } else {
567            tracing::Span::current()
568        };
569
570        span.record("gen_ai.system_instructions", &completion_request.preamble);
571
572        let request =
573            GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
574
575        if enabled!(tracing::Level::TRACE) {
576            tracing::trace!(target: "rig::completions",
577                "Galadriel completion request: {}",
578                serde_json::to_string_pretty(&request)?
579            );
580        }
581
582        let body = serde_json::to_vec(&request)?;
583
584        let req = self
585            .client
586            .post("/chat/completions")?
587            .body(body)
588            .map_err(http_client::Error::from)?;
589
590        async move {
591            let response = self.client.send(req).await?;
592
593            if response.status().is_success() {
594                let t = http_client::text(response).await?;
595
596                if enabled!(tracing::Level::TRACE) {
597                    tracing::trace!(target: "rig::completions",
598                        "Galadriel completion response: {}",
599                        serde_json::to_string_pretty(&t)?
600                    );
601                }
602
603                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
604                    ApiResponse::Ok(response) => {
605                        let span = tracing::Span::current();
606                        span.record("gen_ai.response.id", response.id.clone());
607                        span.record("gen_ai.response.model", response.model.clone());
608                        if let Some(ref usage) = response.usage {
609                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
610                            span.record(
611                                "gen_ai.usage.output_tokens",
612                                usage.total_tokens - usage.prompt_tokens,
613                            );
614                        }
615                        response.try_into()
616                    }
617                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
618                }
619            } else {
620                let text = http_client::text(response).await?;
621
622                Err(CompletionError::ProviderError(text))
623            }
624        }
625        .instrument(span)
626        .await
627    }
628
629    async fn stream(
630        &self,
631        completion_request: CompletionRequest,
632    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
633        let preamble = completion_request.preamble.clone();
634        let mut request =
635            GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
636
637        let params = json_utils::merge(
638            request.additional_params.unwrap_or(serde_json::json!({})),
639            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
640        );
641
642        request.additional_params = Some(params);
643
644        let body = serde_json::to_vec(&request)?;
645
646        let req = self
647            .client
648            .post("/chat/completions")?
649            .body(body)
650            .map_err(http_client::Error::from)?;
651
652        let span = if tracing::Span::current().is_disabled() {
653            info_span!(
654                target: "rig::completions",
655                "chat_streaming",
656                gen_ai.operation.name = "chat_streaming",
657                gen_ai.provider.name = "galadriel",
658                gen_ai.request.model = self.model,
659                gen_ai.system_instructions = preamble,
660                gen_ai.response.id = tracing::field::Empty,
661                gen_ai.response.model = tracing::field::Empty,
662                gen_ai.usage.output_tokens = tracing::field::Empty,
663                gen_ai.usage.input_tokens = tracing::field::Empty,
664                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
665                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
666                gen_ai.output.messages = tracing::field::Empty,
667            )
668        } else {
669            tracing::Span::current()
670        };
671
672        send_compatible_streaming_request(self.client.clone(), req)
673            .instrument(span)
674            .await
675    }
676}
#[cfg(test)]
mod tests {
    use crate::providers::galadriel::Client;

    /// A client can be constructed both directly and through the builder
    /// from a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _client = Client::new("dummy-key").expect("Client::new() failed");

        let builder = Client::builder().api_key("dummy-key");
        let _client_from_builder = builder.build().expect("Client::builder() failed");
    }
}