Skip to main content

rig_core/providers/
galadriel.rs

1//! Galadriel API client and Rig integration
2//!
3//! # Example
4//! ```no_run
5//! use rig_core::{client::CompletionClient, providers::galadriel};
6//!
7//! # fn run() -> Result<(), Box<dyn std::error::Error>> {
8//! let client = galadriel::Client::new("YOUR_API_KEY")?;
9//! // to use a fine-tuned model
10//! // let client = galadriel::Client::builder()
11//! //     .api_key("YOUR_API_KEY")
12//! //     .fine_tune_api_key("FINE_TUNE_API_KEY")
13//! //     .build()?;
14//!
15//! let gpt4o = client.completion_model(galadriel::GPT_4O);
16//! # Ok(())
17//! # }
18//! ```
19use super::openai;
20use crate::client::{
21    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22    ProviderClient,
23};
24use crate::http_client::{self, HttpClientExt};
25use crate::message::MessageError;
26use crate::providers::openai::send_compatible_streaming_request;
27use crate::streaming::StreamingCompletionResponse;
28use crate::{
29    OneOrMany,
30    completion::{self, CompletionError, CompletionRequest},
31    json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34use tracing::{Instrument, enabled, info_span};
35
36// ================================================================
37// Main Galadriel Client
38// ================================================================
/// Base URL of Galadriel's (verified) API; request paths such as `/chat/completions` are
/// appended to it.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";

/// Galadriel-specific state stored on a built client.
///
/// Holds the optional API key for fine-tuned models. `Debug` is implemented by hand so the
/// key itself is never written to logs or panic messages — only whether one is set.
#[derive(Default, Clone)]
pub struct GaladrielExt {
    fine_tune_api_key: Option<String>,
}

/// Galadriel-specific settings collected while building a client; turned into a
/// [`GaladrielExt`] when the client is built.
///
/// Like [`GaladrielExt`], its `Debug` output redacts the fine-tune API key.
#[derive(Default, Clone)]
pub struct GaladrielBuilder {
    fine_tune_api_key: Option<String>,
}

/// Replaces a secret with a fixed placeholder, preserving only whether it is present.
fn redact_secret(secret: &Option<String>) -> Option<&'static str> {
    secret.as_ref().map(|_| "<redacted>")
}

impl std::fmt::Debug for GaladrielExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GaladrielExt")
            .field("fine_tune_api_key", &redact_secret(&self.fine_tune_api_key))
            .finish()
    }
}

impl std::fmt::Debug for GaladrielBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GaladrielBuilder")
            .field("fine_tune_api_key", &redact_secret(&self.fine_tune_api_key))
            .finish()
    }
}
50
/// API-key type for Galadriel clients: the shared [`BearerAuth`] scheme.
type GaladrielApiKey = BearerAuth;
52
/// Registers [`GaladrielExt`] as a rig provider whose clients are configured through
/// [`GaladrielBuilder`].
impl Provider for GaladrielExt {
    type Builder = GaladrielBuilder;

    /// Left empty: there is currently no way to verify a Galadriel API key without
    /// consuming tokens.
    const VERIFY_PATH: &'static str = "";
}
59
/// Declares what a Galadriel client can do: chat completion via [`CompletionModel`] only.
/// Every other capability (including the feature-gated image and audio ones) is `Nothing`.
impl<H> Capabilities<H> for GaladrielExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
    type Rerank = Nothing;
}
71
72impl DebugExt for GaladrielExt {
73    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
74        std::iter::once((
75            "fine_tune_api_key",
76            (&self.fine_tune_api_key as &dyn std::fmt::Debug),
77        ))
78    }
79}
80
81impl ProviderBuilder for GaladrielBuilder {
82    type Extension<H>
83        = GaladrielExt
84    where
85        H: HttpClientExt;
86    type ApiKey = GaladrielApiKey;
87
88    const BASE_URL: &'static str = GALADRIEL_API_BASE_URL;
89
90    fn build<H>(
91        builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
92    ) -> http_client::Result<Self::Extension<H>>
93    where
94        H: HttpClientExt,
95    {
96        let GaladrielBuilder { fine_tune_api_key } = builder.ext().clone();
97
98        Ok(GaladrielExt { fine_tune_api_key })
99    }
100}
101
/// Galadriel client. `H` is the HTTP client backend and defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<GaladrielExt, H>;
/// Builder for a Galadriel [`Client`].
///
/// `H` defaults to `crate::markers::Missing` — presumably a placeholder until an HTTP
/// client is chosen; confirm against `client::ClientBuilder`.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<GaladrielBuilder, GaladrielApiKey, H>;
105
106impl<T> ClientBuilder<T> {
107    pub fn fine_tune_api_key<S>(mut self, fine_tune_api_key: S) -> Self
108    where
109        S: AsRef<str>,
110    {
111        *self.ext_mut() = GaladrielBuilder {
112            fine_tune_api_key: Some(fine_tune_api_key.as_ref().into()),
113        };
114
115        self
116    }
117}
118
119impl ProviderClient for Client {
120    type Input = (String, Option<String>);
121    type Error = crate::client::ProviderClientError;
122
123    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
124    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
125    fn from_env() -> Result<Self, Self::Error> {
126        let api_key = crate::client::required_env_var("GALADRIEL_API_KEY")?;
127        let fine_tune_api_key = crate::client::optional_env_var("GALADRIEL_FINE_TUNE_API_KEY")?;
128
129        let mut builder = Self::builder().api_key(api_key);
130
131        if let Some(fine_tune_api_key) = fine_tune_api_key {
132            builder = builder.fine_tune_api_key(fine_tune_api_key);
133        }
134
135        builder.build().map_err(Into::into)
136    }
137
138    fn from_val((api_key, fine_tune_api_key): Self::Input) -> Result<Self, Self::Error> {
139        let mut builder = Self::builder().api_key(api_key);
140
141        if let Some(fine_tune_key) = fine_tune_api_key {
142            builder = builder.fine_tune_api_key(fine_tune_key)
143        }
144
145        builder.build().map_err(Into::into)
146    }
147}
148
/// Error body returned by the Galadriel API; only its `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// A response body that is either the expected payload `T` or an [`ApiErrorResponse`].
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order, so a body is
/// only treated as an error when it does not deserialize as `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
160
/// Token usage reported by a Galadriel completion response.
///
/// Only prompt and total counts are carried; output tokens are derived from their
/// difference when converting to rig's usage type.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
166
167impl std::fmt::Display for Usage {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        write!(
170            f,
171            "Prompt tokens: {} Total tokens: {}",
172            self.prompt_tokens, self.total_tokens
173        )
174    }
175}
176
177// ================================================================
178// Galadriel Completion API
179// ================================================================
180
// Model identifiers to pass to `Client::completion_model` (see the module-level example).

/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
218pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
219/// `gpt-3.5-turbo-1106` completion model
220pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
221/// `gpt-3.5-turbo-instruct` completion model
222pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
223
/// Raw body of a successful `/chat/completions` response (OpenAI-compatible shape).
///
/// Kept verbatim as `raw_response` when converted into rig's `completion::CompletionResponse`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    /// Candidate completions; only the first one is converted.
    pub choices: Vec<Choice>,
    /// Token usage; when absent, conversion falls back to `completion::Usage::default()`.
    pub usage: Option<Usage>,
}
234
235impl From<ApiErrorResponse> for CompletionError {
236    fn from(err: ApiErrorResponse) -> Self {
237        CompletionError::ProviderError(err.message)
238    }
239}
240
241impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
242    type Error = CompletionError;
243
244    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
245        let Choice { message, .. } = response.choices.first().ok_or_else(|| {
246            CompletionError::ResponseError("Response contained no choices".to_owned())
247        })?;
248
249        let mut content = message
250            .content
251            .as_ref()
252            .map(|c| vec![completion::AssistantContent::text(c)])
253            .unwrap_or_default();
254
255        content.extend(message.tool_calls.iter().map(|call| {
256            completion::AssistantContent::tool_call(
257                &call.function.name,
258                &call.function.name,
259                call.function.arguments.clone(),
260            )
261        }));
262
263        let choice = OneOrMany::many(content).map_err(|_| {
264            CompletionError::ResponseError(
265                "Response contained no message or tool call (empty)".to_owned(),
266            )
267        })?;
268        let usage = response
269            .usage
270            .as_ref()
271            .map(|usage| completion::Usage {
272                input_tokens: usage.prompt_tokens as u64,
273                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
274                total_tokens: usage.total_tokens as u64,
275                cached_input_tokens: 0,
276                cache_creation_input_tokens: 0,
277                tool_use_prompt_tokens: 0,
278                reasoning_tokens: 0,
279            })
280            .unwrap_or_default();
281
282        Ok(completion::CompletionResponse {
283            choice,
284            usage,
285            raw_response: response,
286            message_id: None,
287        })
288    }
289}
290
/// One candidate completion inside a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    /// The generated message (text and/or tool calls).
    pub message: Message,
    /// Passed through untouched as raw JSON.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
298
/// A chat message in Galadriel's wire format.
///
/// `role` is a free-form string; this file produces `"system"`, `"user"` and `"assistant"`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    /// Plain-text content, if any.
    pub content: Option<String>,
    /// Tool calls requested by the model. A missing field yields an empty list;
    /// `json_utils::null_or_vec` presumably maps JSON `null` to an empty list as well.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<openai::ToolCall>,
}
306
307impl Message {
308    fn system(preamble: &str) -> Self {
309        Self {
310            role: "system".to_string(),
311            content: Some(preamble.to_string()),
312            tool_calls: Vec::new(),
313        }
314    }
315}
316
317impl TryFrom<Message> for message::Message {
318    type Error = message::MessageError;
319
320    fn try_from(message: Message) -> Result<Self, Self::Error> {
321        let tool_calls: Vec<message::ToolCall> = message
322            .tool_calls
323            .into_iter()
324            .map(|tool_call| tool_call.into())
325            .collect();
326
327        match message.role.as_str() {
328            "user" => Ok(Self::User {
329                content: OneOrMany::one(
330                    message
331                        .content
332                        .map(|content| message::UserContent::text(&content))
333                        .ok_or_else(|| {
334                            message::MessageError::ConversionError("Empty user message".to_string())
335                        })?,
336                ),
337            }),
338            "assistant" => Ok(Self::Assistant {
339                id: None,
340                content: OneOrMany::many(
341                    tool_calls
342                        .into_iter()
343                        .map(message::AssistantContent::ToolCall)
344                        .chain(
345                            message
346                                .content
347                                .map(|content| message::AssistantContent::text(&content))
348                                .into_iter(),
349                        ),
350                )
351                .map_err(|_| {
352                    message::MessageError::ConversionError("Empty assistant message".to_string())
353                })?,
354            }),
355            _ => Err(message::MessageError::ConversionError(format!(
356                "Unknown role: {}",
357                message.role
358            ))),
359        }
360    }
361}
362
363impl TryFrom<message::Message> for Message {
364    type Error = message::MessageError;
365
366    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
367        match message {
368            message::Message::System { content } => Ok(Self {
369                role: "system".to_string(),
370                content: Some(content),
371                tool_calls: vec![],
372            }),
373            message::Message::User { content } => Ok(Self {
374                role: "user".to_string(),
375                content: content.iter().find_map(|c| match c {
376                    message::UserContent::Text(text) => Some(text.text.clone()),
377                    _ => None,
378                }),
379                tool_calls: vec![],
380            }),
381            message::Message::Assistant { content, .. } => {
382                let mut text_content: Option<String> = None;
383                let mut tool_calls = vec![];
384
385                for c in content.iter() {
386                    match c {
387                        message::AssistantContent::Text(text) => {
388                            text_content = Some(
389                                text_content
390                                    .map(|mut existing| {
391                                        existing.push('\n');
392                                        existing.push_str(&text.text);
393                                        existing
394                                    })
395                                    .unwrap_or_else(|| text.text.clone()),
396                            );
397                        }
398                        message::AssistantContent::ToolCall(tool_call) => {
399                            tool_calls.push(tool_call.clone().into());
400                        }
401                        message::AssistantContent::Reasoning(_) => {
402                            return Err(MessageError::ConversionError(
403                                "Galadriel currently doesn't support reasoning.".into(),
404                            ));
405                        }
406                        message::AssistantContent::Image(_) => {
407                            return Err(MessageError::ConversionError(
408                                "Galadriel currently doesn't support images.".into(),
409                            ));
410                        }
411                    }
412                }
413
414                Ok(Self {
415                    role: "assistant".to_string(),
416                    content: text_content,
417                    tool_calls,
418                })
419            }
420        }
421    }
422}
423
/// Tool entry for the request's `tools` array: a `type` tag (the `From` conversion below
/// sets it to `"function"`) wrapping rig's [`completion::ToolDefinition`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
429
430impl From<completion::ToolDefinition> for ToolDefinition {
431    fn from(tool: completion::ToolDefinition) -> Self {
432        Self {
433            r#type: "function".into(),
434            function: tool,
435        }
436    }
437}
438
/// A function name paired with its arguments as a raw string.
///
/// NOTE(review): this type is not referenced anywhere in this file; tool calls are
/// deserialized as `openai::ToolCall` instead. Confirm nothing else uses it before removing.
#[derive(Debug, Deserialize)]
pub struct Function {
    pub name: String,
    pub arguments: String,
}
444
/// JSON body sent to `POST /chat/completions`.
///
/// Unset optional fields and an empty tool list are omitted from the serialized body.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GaladrielCompletionRequest {
    /// Model name; a per-request override takes precedence over the client's model.
    model: String,
    /// System preamble (if any) followed by the converted chat history.
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    /// Extra top-level fields, flattened into the body. Streaming uses this to inject
    /// `stream` and `stream_options`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
458
459impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest {
460    type Error = CompletionError;
461
462    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
463        let chat_history = req.chat_history_with_documents();
464        if req.output_schema.is_some() {
465            tracing::warn!("Structured outputs currently not supported for Galadriel");
466        }
467        let model = req.model.clone().unwrap_or_else(|| model.to_string());
468        // Build up the order of messages.
469        let mut partial_history = vec![];
470        partial_history.extend(chat_history);
471
472        // Add preamble to chat history (if available)
473        let mut full_history: Vec<Message> = match &req.preamble {
474            Some(preamble) => vec![Message::system(preamble)],
475            None => vec![],
476        };
477
478        // Convert and extend the rest of the history
479        full_history.extend(
480            partial_history
481                .into_iter()
482                .map(message::Message::try_into)
483                .collect::<Result<Vec<Message>, _>>()?,
484        );
485
486        let tool_choice = req
487            .tool_choice
488            .clone()
489            .map(crate::providers::openai::completion::ToolChoice::try_from)
490            .transpose()?;
491
492        Ok(Self {
493            model: model.to_string(),
494            messages: full_history,
495            temperature: req.temperature,
496            tools: req
497                .tools
498                .clone()
499                .into_iter()
500                .map(ToolDefinition::from)
501                .collect::<Vec<_>>(),
502            tool_choice,
503            additional_params: req.additional_params,
504        })
505    }
506}
507
/// Handle for running Galadriel chat completions against one model.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests.
    client: Client<T>,
    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
    pub model: String,
}
514
515impl<T> CompletionModel<T>
516where
517    T: HttpClientExt,
518{
519    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
520        Self {
521            client,
522            model: model.into(),
523        }
524    }
525
526    pub fn with_model(client: Client<T>, model: &str) -> Self {
527        Self {
528            client,
529            model: model.into(),
530        }
531    }
532}
533
534impl<T> completion::CompletionModel for CompletionModel<T>
535where
536    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
537{
538    type Response = CompletionResponse;
539    type StreamingResponse = openai::StreamingCompletionResponse;
540
541    type Client = Client<T>;
542
543    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
544        Self::new(client.clone(), model.into())
545    }
546
547    async fn completion(
548        &self,
549        completion_request: CompletionRequest,
550    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
551        let span = if tracing::Span::current().is_disabled() {
552            info_span!(
553                target: "rig::completions",
554                "chat",
555                gen_ai.operation.name = "chat",
556                gen_ai.provider.name = "galadriel",
557                gen_ai.request.model = self.model,
558                gen_ai.system_instructions = tracing::field::Empty,
559                gen_ai.response.id = tracing::field::Empty,
560                gen_ai.response.model = tracing::field::Empty,
561                gen_ai.usage.output_tokens = tracing::field::Empty,
562                gen_ai.usage.input_tokens = tracing::field::Empty,
563                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
564            )
565        } else {
566            tracing::Span::current()
567        };
568
569        span.record("gen_ai.system_instructions", &completion_request.preamble);
570
571        let request =
572            GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
573
574        if enabled!(tracing::Level::TRACE) {
575            tracing::trace!(target: "rig::completions",
576                "Galadriel completion request: {}",
577                serde_json::to_string_pretty(&request)?
578            );
579        }
580
581        let body = serde_json::to_vec(&request)?;
582
583        let req = self
584            .client
585            .post("/chat/completions")?
586            .body(body)
587            .map_err(http_client::Error::from)?;
588
589        async move {
590            let response = self.client.send(req).await?;
591
592            if response.status().is_success() {
593                let t = http_client::text(response).await?;
594
595                if enabled!(tracing::Level::TRACE) {
596                    tracing::trace!(target: "rig::completions",
597                        "Galadriel completion response: {}",
598                        serde_json::to_string_pretty(&t)?
599                    );
600                }
601
602                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
603                    ApiResponse::Ok(response) => {
604                        let span = tracing::Span::current();
605                        span.record("gen_ai.response.id", response.id.clone());
606                        span.record("gen_ai.response.model", response.model.clone());
607                        if let Some(ref usage) = response.usage {
608                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
609                            span.record(
610                                "gen_ai.usage.output_tokens",
611                                usage.total_tokens - usage.prompt_tokens,
612                            );
613                        }
614                        response.try_into()
615                    }
616                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
617                }
618            } else {
619                let text = http_client::text(response).await?;
620
621                Err(CompletionError::ProviderError(text))
622            }
623        }
624        .instrument(span)
625        .await
626    }
627
628    async fn stream(
629        &self,
630        completion_request: CompletionRequest,
631    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
632        let preamble = completion_request.preamble.clone();
633        let mut request =
634            GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
635
636        let params = json_utils::merge(
637            request.additional_params.unwrap_or(serde_json::json!({})),
638            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
639        );
640
641        request.additional_params = Some(params);
642
643        let body = serde_json::to_vec(&request)?;
644
645        let req = self
646            .client
647            .post("/chat/completions")?
648            .body(body)
649            .map_err(http_client::Error::from)?;
650
651        let span = if tracing::Span::current().is_disabled() {
652            info_span!(
653                target: "rig::completions",
654                "chat_streaming",
655                gen_ai.operation.name = "chat_streaming",
656                gen_ai.provider.name = "galadriel",
657                gen_ai.request.model = self.model,
658                gen_ai.system_instructions = preamble,
659                gen_ai.response.id = tracing::field::Empty,
660                gen_ai.response.model = tracing::field::Empty,
661                gen_ai.usage.output_tokens = tracing::field::Empty,
662                gen_ai.usage.input_tokens = tracing::field::Empty,
663                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
664                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
665                gen_ai.output.messages = tracing::field::Empty,
666            )
667        } else {
668            tracing::Span::current()
669        };
670
671        send_compatible_streaming_request(self.client.clone(), req)
672            .instrument(span)
673            .await
674    }
675}
#[cfg(test)]
mod tests {
    use super::Client;

    /// Both construction paths — `Client::new` and the builder — must succeed with a
    /// placeholder key.
    #[test]
    fn test_client_initialization() {
        let _ = Client::new("dummy-key").expect("Client::new() failed");

        let _ = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}