Skip to main content

rig/providers/
galadriel.rs

//! Galadriel API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::galadriel;
//!
//! let client = galadriel::Client::new("YOUR_API_KEY").expect("failed to build Galadriel client");
//! // to use a fine-tuned model
//! // let client = galadriel::Client::builder()
//! //     .api_key("YOUR_API_KEY")
//! //     .fine_tune_api_key("FINE_TUNE_API_KEY")
//! //     .build()
//! //     .expect("failed to build Galadriel client");
//!
//! let gpt4o = client.completion_model(galadriel::GPT_4O);
//! ```
13use super::openai;
14use crate::client::{
15    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
16    ProviderClient,
17};
18use crate::http_client::{self, HttpClientExt};
19use crate::message::MessageError;
20use crate::providers::openai::send_compatible_streaming_request;
21use crate::streaming::StreamingCompletionResponse;
22use crate::{
23    OneOrMany,
24    completion::{self, CompletionError, CompletionRequest},
25    json_utils, message,
26};
27use serde::{Deserialize, Serialize};
28use tracing::{Instrument, enabled, info_span};
29
30// ================================================================
31// Main Galadriel Client
32// ================================================================
/// Base URL used as the default `ProviderBuilder::BASE_URL` for Galadriel clients.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
34
/// Provider-specific client state for Galadriel.
///
/// Produced from a [`GaladrielBuilder`] when the client is built.
#[derive(Clone, Debug, Default)]
pub struct GaladrielExt {
    /// Fine-tune API key, if one was configured on the builder.
    fine_tune_api_key: Option<String>,
}
39
/// Builder-side configuration for a Galadriel client.
///
/// Defaults to having no fine-tune API key.
#[derive(Clone, Debug, Default)]
pub struct GaladrielBuilder {
    /// Fine-tune API key to carry over into the built client extension.
    fine_tune_api_key: Option<String>,
}
44
45type GaladrielApiKey = BearerAuth;
46
47impl Provider for GaladrielExt {
48    type Builder = GaladrielBuilder;
49
50    /// There is currently no way to verify a Galadriel api key without consuming tokens
51    const VERIFY_PATH: &'static str = "";
52}
53
54impl<H> Capabilities<H> for GaladrielExt {
55    type Completion = Capable<CompletionModel<H>>;
56    type Embeddings = Nothing;
57    type Transcription = Nothing;
58    type ModelListing = Nothing;
59    #[cfg(feature = "image")]
60    type ImageGeneration = Nothing;
61    #[cfg(feature = "audio")]
62    type AudioGeneration = Nothing;
63}
64
65impl DebugExt for GaladrielExt {
66    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
67        std::iter::once((
68            "fine_tune_api_key",
69            (&self.fine_tune_api_key as &dyn std::fmt::Debug),
70        ))
71    }
72}
73
74impl ProviderBuilder for GaladrielBuilder {
75    type Extension<H>
76        = GaladrielExt
77    where
78        H: HttpClientExt;
79    type ApiKey = GaladrielApiKey;
80
81    const BASE_URL: &'static str = GALADRIEL_API_BASE_URL;
82
83    fn build<H>(
84        builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
85    ) -> http_client::Result<Self::Extension<H>>
86    where
87        H: HttpClientExt,
88    {
89        let GaladrielBuilder { fine_tune_api_key } = builder.ext().clone();
90
91        Ok(GaladrielExt { fine_tune_api_key })
92    }
93}
94
95pub type Client<H = reqwest::Client> = client::Client<GaladrielExt, H>;
96pub type ClientBuilder<H = reqwest::Client> =
97    client::ClientBuilder<GaladrielBuilder, GaladrielApiKey, H>;
98
99impl<T> ClientBuilder<T> {
100    pub fn fine_tune_api_key<S>(mut self, fine_tune_api_key: S) -> Self
101    where
102        S: AsRef<str>,
103    {
104        *self.ext_mut() = GaladrielBuilder {
105            fine_tune_api_key: Some(fine_tune_api_key.as_ref().into()),
106        };
107
108        self
109    }
110}
111
112impl ProviderClient for Client {
113    type Input = (String, Option<String>);
114
115    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
116    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
117    /// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
118    fn from_env() -> Self {
119        let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
120        let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
121
122        let mut builder = Self::builder().api_key(api_key);
123
124        if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
125            builder = builder.fine_tune_api_key(fine_tune_api_key);
126        }
127
128        builder.build().unwrap()
129    }
130
131    fn from_val((api_key, fine_tune_api_key): Self::Input) -> Self {
132        let mut builder = Self::builder().api_key(api_key);
133
134        if let Some(fine_tune_key) = fine_tune_api_key {
135            builder = builder.fine_tune_api_key(fine_tune_key)
136        }
137
138        builder.build().unwrap()
139    }
140}
141
142#[derive(Debug, Deserialize)]
143struct ApiErrorResponse {
144    message: String,
145}
146
147#[derive(Debug, Deserialize)]
148#[serde(untagged)]
149enum ApiResponse<T> {
150    Ok(T),
151    Err(ApiErrorResponse),
152}
153
154#[derive(Clone, Debug, Deserialize, Serialize)]
155pub struct Usage {
156    pub prompt_tokens: usize,
157    pub total_tokens: usize,
158}
159
160impl std::fmt::Display for Usage {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        write!(
163            f,
164            "Prompt tokens: {} Total tokens: {}",
165            self.prompt_tokens, self.total_tokens
166        )
167    }
168}
169
170// ================================================================
171// Galadriel Completion API
172// ================================================================
173
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
216
217#[derive(Debug, Deserialize, Serialize)]
218pub struct CompletionResponse {
219    pub id: String,
220    pub object: String,
221    pub created: u64,
222    pub model: String,
223    pub system_fingerprint: Option<String>,
224    pub choices: Vec<Choice>,
225    pub usage: Option<Usage>,
226}
227
228impl From<ApiErrorResponse> for CompletionError {
229    fn from(err: ApiErrorResponse) -> Self {
230        CompletionError::ProviderError(err.message)
231    }
232}
233
234impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
235    type Error = CompletionError;
236
237    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
238        let Choice { message, .. } = response.choices.first().ok_or_else(|| {
239            CompletionError::ResponseError("Response contained no choices".to_owned())
240        })?;
241
242        let mut content = message
243            .content
244            .as_ref()
245            .map(|c| vec![completion::AssistantContent::text(c)])
246            .unwrap_or_default();
247
248        content.extend(message.tool_calls.iter().map(|call| {
249            completion::AssistantContent::tool_call(
250                &call.function.name,
251                &call.function.name,
252                call.function.arguments.clone(),
253            )
254        }));
255
256        let choice = OneOrMany::many(content).map_err(|_| {
257            CompletionError::ResponseError(
258                "Response contained no message or tool call (empty)".to_owned(),
259            )
260        })?;
261        let usage = response
262            .usage
263            .as_ref()
264            .map(|usage| completion::Usage {
265                input_tokens: usage.prompt_tokens as u64,
266                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
267                total_tokens: usage.total_tokens as u64,
268                cached_input_tokens: 0,
269                cache_creation_input_tokens: 0,
270            })
271            .unwrap_or_default();
272
273        Ok(completion::CompletionResponse {
274            choice,
275            usage,
276            raw_response: response,
277            message_id: None,
278        })
279    }
280}
281
282#[derive(Debug, Deserialize, Serialize)]
283pub struct Choice {
284    pub index: usize,
285    pub message: Message,
286    pub logprobs: Option<serde_json::Value>,
287    pub finish_reason: String,
288}
289
290#[derive(Debug, Serialize, Deserialize)]
291pub struct Message {
292    pub role: String,
293    pub content: Option<String>,
294    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
295    pub tool_calls: Vec<openai::ToolCall>,
296}
297
298impl Message {
299    fn system(preamble: &str) -> Self {
300        Self {
301            role: "system".to_string(),
302            content: Some(preamble.to_string()),
303            tool_calls: Vec::new(),
304        }
305    }
306}
307
308impl TryFrom<Message> for message::Message {
309    type Error = message::MessageError;
310
311    fn try_from(message: Message) -> Result<Self, Self::Error> {
312        let tool_calls: Vec<message::ToolCall> = message
313            .tool_calls
314            .into_iter()
315            .map(|tool_call| tool_call.into())
316            .collect();
317
318        match message.role.as_str() {
319            "user" => Ok(Self::User {
320                content: OneOrMany::one(
321                    message
322                        .content
323                        .map(|content| message::UserContent::text(&content))
324                        .ok_or_else(|| {
325                            message::MessageError::ConversionError("Empty user message".to_string())
326                        })?,
327                ),
328            }),
329            "assistant" => Ok(Self::Assistant {
330                id: None,
331                content: OneOrMany::many(
332                    tool_calls
333                        .into_iter()
334                        .map(message::AssistantContent::ToolCall)
335                        .chain(
336                            message
337                                .content
338                                .map(|content| message::AssistantContent::text(&content))
339                                .into_iter(),
340                        ),
341                )
342                .map_err(|_| {
343                    message::MessageError::ConversionError("Empty assistant message".to_string())
344                })?,
345            }),
346            _ => Err(message::MessageError::ConversionError(format!(
347                "Unknown role: {}",
348                message.role
349            ))),
350        }
351    }
352}
353
354impl TryFrom<message::Message> for Message {
355    type Error = message::MessageError;
356
357    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
358        match message {
359            message::Message::System { content } => Ok(Self {
360                role: "system".to_string(),
361                content: Some(content),
362                tool_calls: vec![],
363            }),
364            message::Message::User { content } => Ok(Self {
365                role: "user".to_string(),
366                content: content.iter().find_map(|c| match c {
367                    message::UserContent::Text(text) => Some(text.text.clone()),
368                    _ => None,
369                }),
370                tool_calls: vec![],
371            }),
372            message::Message::Assistant { content, .. } => {
373                let mut text_content: Option<String> = None;
374                let mut tool_calls = vec![];
375
376                for c in content.iter() {
377                    match c {
378                        message::AssistantContent::Text(text) => {
379                            text_content = Some(
380                                text_content
381                                    .map(|mut existing| {
382                                        existing.push('\n');
383                                        existing.push_str(&text.text);
384                                        existing
385                                    })
386                                    .unwrap_or_else(|| text.text.clone()),
387                            );
388                        }
389                        message::AssistantContent::ToolCall(tool_call) => {
390                            tool_calls.push(tool_call.clone().into());
391                        }
392                        message::AssistantContent::Reasoning(_) => {
393                            return Err(MessageError::ConversionError(
394                                "Galadriel currently doesn't support reasoning.".into(),
395                            ));
396                        }
397                        message::AssistantContent::Image(_) => {
398                            return Err(MessageError::ConversionError(
399                                "Galadriel currently doesn't support images.".into(),
400                            ));
401                        }
402                    }
403                }
404
405                Ok(Self {
406                    role: "assistant".to_string(),
407                    content: text_content,
408                    tool_calls,
409                })
410            }
411        }
412    }
413}
414
415#[derive(Clone, Debug, Deserialize, Serialize)]
416pub struct ToolDefinition {
417    pub r#type: String,
418    pub function: completion::ToolDefinition,
419}
420
421impl From<completion::ToolDefinition> for ToolDefinition {
422    fn from(tool: completion::ToolDefinition) -> Self {
423        Self {
424            r#type: "function".into(),
425            function: tool,
426        }
427    }
428}
429
430#[derive(Debug, Deserialize)]
431pub struct Function {
432    pub name: String,
433    pub arguments: String,
434}
435
436#[derive(Debug, Serialize, Deserialize)]
437pub(super) struct GaladrielCompletionRequest {
438    model: String,
439    pub messages: Vec<Message>,
440    #[serde(skip_serializing_if = "Option::is_none")]
441    temperature: Option<f64>,
442    #[serde(skip_serializing_if = "Vec::is_empty")]
443    tools: Vec<ToolDefinition>,
444    #[serde(skip_serializing_if = "Option::is_none")]
445    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
446    #[serde(flatten, skip_serializing_if = "Option::is_none")]
447    pub additional_params: Option<serde_json::Value>,
448}
449
450impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest {
451    type Error = CompletionError;
452
453    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
454        if req.output_schema.is_some() {
455            tracing::warn!("Structured outputs currently not supported for Galadriel");
456        }
457        let model = req.model.clone().unwrap_or_else(|| model.to_string());
458        // Build up the order of messages (context, chat_history, prompt)
459        let mut partial_history = vec![];
460        if let Some(docs) = req.normalized_documents() {
461            partial_history.push(docs);
462        }
463        partial_history.extend(req.chat_history);
464
465        // Add preamble to chat history (if available)
466        let mut full_history: Vec<Message> = match &req.preamble {
467            Some(preamble) => vec![Message::system(preamble)],
468            None => vec![],
469        };
470
471        // Convert and extend the rest of the history
472        full_history.extend(
473            partial_history
474                .into_iter()
475                .map(message::Message::try_into)
476                .collect::<Result<Vec<Message>, _>>()?,
477        );
478
479        let tool_choice = req
480            .tool_choice
481            .clone()
482            .map(crate::providers::openai::completion::ToolChoice::try_from)
483            .transpose()?;
484
485        Ok(Self {
486            model: model.to_string(),
487            messages: full_history,
488            temperature: req.temperature,
489            tools: req
490                .tools
491                .clone()
492                .into_iter()
493                .map(ToolDefinition::from)
494                .collect::<Vec<_>>(),
495            tool_choice,
496            additional_params: req.additional_params,
497        })
498    }
499}
500
501#[derive(Clone)]
502pub struct CompletionModel<T = reqwest::Client> {
503    client: Client<T>,
504    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
505    pub model: String,
506}
507
508impl<T> CompletionModel<T>
509where
510    T: HttpClientExt,
511{
512    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
513        Self {
514            client,
515            model: model.into(),
516        }
517    }
518
519    pub fn with_model(client: Client<T>, model: &str) -> Self {
520        Self {
521            client,
522            model: model.into(),
523        }
524    }
525}
526
527impl<T> completion::CompletionModel for CompletionModel<T>
528where
529    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
530{
531    type Response = CompletionResponse;
532    type StreamingResponse = openai::StreamingCompletionResponse;
533
534    type Client = Client<T>;
535
536    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
537        Self::new(client.clone(), model.into())
538    }
539
540    async fn completion(
541        &self,
542        completion_request: CompletionRequest,
543    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
544        let span = if tracing::Span::current().is_disabled() {
545            info_span!(
546                target: "rig::completions",
547                "chat",
548                gen_ai.operation.name = "chat",
549                gen_ai.provider.name = "galadriel",
550                gen_ai.request.model = self.model,
551                gen_ai.system_instructions = tracing::field::Empty,
552                gen_ai.response.id = tracing::field::Empty,
553                gen_ai.response.model = tracing::field::Empty,
554                gen_ai.usage.output_tokens = tracing::field::Empty,
555                gen_ai.usage.input_tokens = tracing::field::Empty,
556                gen_ai.usage.cached_tokens = tracing::field::Empty,
557            )
558        } else {
559            tracing::Span::current()
560        };
561
562        span.record("gen_ai.system_instructions", &completion_request.preamble);
563
564        let request =
565            GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
566
567        if enabled!(tracing::Level::TRACE) {
568            tracing::trace!(target: "rig::completions",
569                "Galadriel completion request: {}",
570                serde_json::to_string_pretty(&request)?
571            );
572        }
573
574        let body = serde_json::to_vec(&request)?;
575
576        let req = self
577            .client
578            .post("/chat/completions")?
579            .body(body)
580            .map_err(http_client::Error::from)?;
581
582        async move {
583            let response = self.client.send(req).await?;
584
585            if response.status().is_success() {
586                let t = http_client::text(response).await?;
587
588                if enabled!(tracing::Level::TRACE) {
589                    tracing::trace!(target: "rig::completions",
590                        "Galadriel completion response: {}",
591                        serde_json::to_string_pretty(&t)?
592                    );
593                }
594
595                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
596                    ApiResponse::Ok(response) => {
597                        let span = tracing::Span::current();
598                        span.record("gen_ai.response.id", response.id.clone());
599                        span.record("gen_ai.response.model_name", response.model.clone());
600                        if let Some(ref usage) = response.usage {
601                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
602                            span.record(
603                                "gen_ai.usage.output_tokens",
604                                usage.total_tokens - usage.prompt_tokens,
605                            );
606                        }
607                        response.try_into()
608                    }
609                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
610                }
611            } else {
612                let text = http_client::text(response).await?;
613
614                Err(CompletionError::ProviderError(text))
615            }
616        }
617        .instrument(span)
618        .await
619    }
620
621    async fn stream(
622        &self,
623        completion_request: CompletionRequest,
624    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
625        let preamble = completion_request.preamble.clone();
626        let mut request =
627            GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
628
629        let params = json_utils::merge(
630            request.additional_params.unwrap_or(serde_json::json!({})),
631            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
632        );
633
634        request.additional_params = Some(params);
635
636        let body = serde_json::to_vec(&request)?;
637
638        let req = self
639            .client
640            .post("/chat/completions")?
641            .body(body)
642            .map_err(http_client::Error::from)?;
643
644        let span = if tracing::Span::current().is_disabled() {
645            info_span!(
646                target: "rig::completions",
647                "chat_streaming",
648                gen_ai.operation.name = "chat_streaming",
649                gen_ai.provider.name = "galadriel",
650                gen_ai.request.model = self.model,
651                gen_ai.system_instructions = preamble,
652                gen_ai.response.id = tracing::field::Empty,
653                gen_ai.response.model = tracing::field::Empty,
654                gen_ai.usage.output_tokens = tracing::field::Empty,
655                gen_ai.usage.input_tokens = tracing::field::Empty,
656                gen_ai.usage.cached_tokens = tracing::field::Empty,
657                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
658                gen_ai.output.messages = tracing::field::Empty,
659            )
660        } else {
661            tracing::Span::current()
662        };
663
664        send_compatible_streaming_request(self.client.clone(), req)
665            .instrument(span)
666            .await
667    }
668}
#[cfg(test)]
mod tests {
    use crate::providers::galadriel::Client;

    /// A client can be constructed both directly and through the builder.
    #[test]
    fn test_client_initialization() {
        let direct = Client::new("dummy-key");
        let _client = direct.expect("Client::new() failed");

        let built = Client::builder().api_key("dummy-key").build();
        let _client_from_builder = built.expect("Client::builder() failed");
    }
}