Skip to main content

rig/providers/
galadriel.rs

1//! Galadriel API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::galadriel;
6//!
7//! let client = galadriel::Client::new("YOUR_API_KEY", None);
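//! // to use a fine-tuned model, also set the fine-tune API key on the builder:
//! // let client = galadriel::Client::builder()
//! //     .api_key("YOUR_API_KEY")
//! //     .fine_tune_api_key("FINE_TUNE_API_KEY")
//! //     .build()?;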
10//!
11//! let gpt4o = client.completion_model(galadriel::GPT_4O);
12//! ```
13use super::openai;
14use crate::client::{
15    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
16    ProviderClient,
17};
18use crate::http_client::{self, HttpClientExt};
19use crate::message::MessageError;
20use crate::providers::openai::send_compatible_streaming_request;
21use crate::streaming::StreamingCompletionResponse;
22use crate::{
23    OneOrMany,
24    completion::{self, CompletionError, CompletionRequest},
25    json_utils, message,
26};
27use serde::{Deserialize, Serialize};
28use tracing::{Instrument, enabled, info_span};
29
30// ================================================================
31// Main Galadriel Client
32// ================================================================
/// Base URL of Galadriel's verified API; request paths such as `/chat/completions`
/// are resolved against it.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";

/// Galadriel-specific state stored on a built client.
///
/// Holds the optional API key used for fine-tuned models. `Debug` is implemented by
/// hand so the key itself never ends up in logs.
#[derive(Default, Clone)]
pub struct GaladrielExt {
    fine_tune_api_key: Option<String>,
}

/// Galadriel-specific state collected while configuring a client builder.
///
/// `Debug` is implemented by hand so the fine-tune key itself never ends up in logs.
#[derive(Default, Clone)]
pub struct GaladrielBuilder {
    fine_tune_api_key: Option<String>,
}

impl std::fmt::Debug for GaladrielExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GaladrielExt")
            // Show only whether a key is configured, never its value.
            .field(
                "fine_tune_api_key",
                &self.fine_tune_api_key.as_ref().map(|_| "<REDACTED>"),
            )
            .finish()
    }
}

impl std::fmt::Debug for GaladrielBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GaladrielBuilder")
            // Show only whether a key is configured, never its value.
            .field(
                "fine_tune_api_key",
                &self.fine_tune_api_key.as_ref().map(|_| "<REDACTED>"),
            )
            .finish()
    }
}
44
/// API key type used by the Galadriel client builder.
///
/// NOTE(review): aliased to `BearerAuth`, so presumably sent as an
/// `Authorization: Bearer …` header — confirm against `crate::client`.
type GaladrielApiKey = BearerAuth;
46
47impl Provider for GaladrielExt {
48    type Builder = GaladrielBuilder;
49
50    /// There is currently no way to verify a Galadriel api key without consuming tokens
51    const VERIFY_PATH: &'static str = "";
52
53    fn build<H>(
54        builder: &crate::client::ClientBuilder<
55            Self::Builder,
56            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
57            H,
58        >,
59    ) -> http_client::Result<Self> {
60        let GaladrielBuilder { fine_tune_api_key } = builder.ext().clone();
61
62        Ok(Self { fine_tune_api_key })
63    }
64}
65
/// Capabilities of the Galadriel provider: chat completions only.
///
/// Embeddings, transcription and (when the features are enabled) image and
/// audio generation are all marked `Nothing`.
impl<H> Capabilities<H> for GaladrielExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
75
76impl DebugExt for GaladrielExt {
77    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
78        std::iter::once((
79            "fine_tune_api_key",
80            (&self.fine_tune_api_key as &dyn std::fmt::Debug),
81        ))
82    }
83}
84
85impl ProviderBuilder for GaladrielBuilder {
86    type Output = GaladrielExt;
87    type ApiKey = GaladrielApiKey;
88
89    const BASE_URL: &'static str = GALADRIEL_API_BASE_URL;
90}
91
92pub type Client<H = reqwest::Client> = client::Client<GaladrielExt, H>;
93pub type ClientBuilder<H = reqwest::Client> =
94    client::ClientBuilder<GaladrielBuilder, GaladrielApiKey, H>;
95
96impl<T> ClientBuilder<T> {
97    pub fn fine_tune_api_key<S>(mut self, fine_tune_api_key: S) -> Self
98    where
99        S: AsRef<str>,
100    {
101        *self.ext_mut() = GaladrielBuilder {
102            fine_tune_api_key: Some(fine_tune_api_key.as_ref().into()),
103        };
104
105        self
106    }
107}
108
109impl ProviderClient for Client {
110    type Input = (String, Option<String>);
111
112    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
113    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
114    /// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
115    fn from_env() -> Self {
116        let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
117        let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
118
119        let mut builder = Self::builder().api_key(api_key);
120
121        if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
122            builder = builder.fine_tune_api_key(fine_tune_api_key);
123        }
124
125        builder.build().unwrap()
126    }
127
128    fn from_val((api_key, fine_tune_api_key): Self::Input) -> Self {
129        let mut builder = Self::builder().api_key(api_key);
130
131        if let Some(fine_tune_key) = fine_tune_api_key {
132            builder = builder.fine_tune_api_key(fine_tune_key)
133        }
134
135        builder.build().unwrap()
136    }
137}
138
/// Error payload returned by the Galadriel API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error description; surfaced to callers as `CompletionError::ProviderError`.
    message: String,
}

/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order, so
/// `Ok(T)` must stay first: a body is treated as an error only when it does not
/// parse as `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
150
151#[derive(Clone, Debug, Deserialize, Serialize)]
152pub struct Usage {
153    pub prompt_tokens: usize,
154    pub total_tokens: usize,
155}
156
157impl std::fmt::Display for Usage {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        write!(
160            f,
161            "Prompt tokens: {} Total tokens: {}",
162            self.prompt_tokens, self.total_tokens
163        )
164    }
165}
166
167// ================================================================
168// Galadriel Completion API
169// ================================================================
170
// Model identifiers accepted by `Client::completion_model`; each constant's value is
// the exact model string sent in the request body.

/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
213
/// Raw chat-completion response body returned by Galadriel's `/chat/completions`.
///
/// After conversion it is kept verbatim as the `raw_response` of Rig's
/// `completion::CompletionResponse`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Provider-assigned response id (recorded on the tracing span).
    pub id: String,
    /// Object type tag reported by the API.
    pub object: String,
    /// Creation timestamp as reported by the API (presumably Unix seconds — TODO confirm).
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    /// Optional backend fingerprint reported by the API.
    pub system_fingerprint: Option<String>,
    /// Candidate completions; only the first one is converted to Rig content.
    pub choices: Vec<Choice>,
    /// Token usage, when the API reports it.
    pub usage: Option<Usage>,
}
224
225impl From<ApiErrorResponse> for CompletionError {
226    fn from(err: ApiErrorResponse) -> Self {
227        CompletionError::ProviderError(err.message)
228    }
229}
230
231impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
232    type Error = CompletionError;
233
234    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
235        let Choice { message, .. } = response.choices.first().ok_or_else(|| {
236            CompletionError::ResponseError("Response contained no choices".to_owned())
237        })?;
238
239        let mut content = message
240            .content
241            .as_ref()
242            .map(|c| vec![completion::AssistantContent::text(c)])
243            .unwrap_or_default();
244
245        content.extend(message.tool_calls.iter().map(|call| {
246            completion::AssistantContent::tool_call(
247                &call.function.name,
248                &call.function.name,
249                call.function.arguments.clone(),
250            )
251        }));
252
253        let choice = OneOrMany::many(content).map_err(|_| {
254            CompletionError::ResponseError(
255                "Response contained no message or tool call (empty)".to_owned(),
256            )
257        })?;
258        let usage = response
259            .usage
260            .as_ref()
261            .map(|usage| completion::Usage {
262                input_tokens: usage.prompt_tokens as u64,
263                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
264                total_tokens: usage.total_tokens as u64,
265                cached_input_tokens: 0,
266            })
267            .unwrap_or_default();
268
269        Ok(completion::CompletionResponse {
270            choice,
271            usage,
272            raw_response: response,
273        })
274    }
275}
276
/// One candidate completion inside a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Index reported by the API (presumably the position in `choices`).
    pub index: usize,
    /// The generated message, including any requested tool calls.
    pub message: Message,
    /// Log-probability data, kept as untyped JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped, as reported by the API (e.g. `"stop"` — TODO confirm values).
    pub finish_reason: String,
}
284
285#[derive(Debug, Serialize, Deserialize)]
286pub struct Message {
287    pub role: String,
288    pub content: Option<String>,
289    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
290    pub tool_calls: Vec<openai::ToolCall>,
291}
292
293impl Message {
294    fn system(preamble: &str) -> Self {
295        Self {
296            role: "system".to_string(),
297            content: Some(preamble.to_string()),
298            tool_calls: Vec::new(),
299        }
300    }
301}
302
303impl TryFrom<Message> for message::Message {
304    type Error = message::MessageError;
305
306    fn try_from(message: Message) -> Result<Self, Self::Error> {
307        let tool_calls: Vec<message::ToolCall> = message
308            .tool_calls
309            .into_iter()
310            .map(|tool_call| tool_call.into())
311            .collect();
312
313        match message.role.as_str() {
314            "user" => Ok(Self::User {
315                content: OneOrMany::one(
316                    message
317                        .content
318                        .map(|content| message::UserContent::text(&content))
319                        .ok_or_else(|| {
320                            message::MessageError::ConversionError("Empty user message".to_string())
321                        })?,
322                ),
323            }),
324            "assistant" => Ok(Self::Assistant {
325                id: None,
326                content: OneOrMany::many(
327                    tool_calls
328                        .into_iter()
329                        .map(message::AssistantContent::ToolCall)
330                        .chain(
331                            message
332                                .content
333                                .map(|content| message::AssistantContent::text(&content))
334                                .into_iter(),
335                        ),
336                )
337                .map_err(|_| {
338                    message::MessageError::ConversionError("Empty assistant message".to_string())
339                })?,
340            }),
341            _ => Err(message::MessageError::ConversionError(format!(
342                "Unknown role: {}",
343                message.role
344            ))),
345        }
346    }
347}
348
349impl TryFrom<message::Message> for Message {
350    type Error = message::MessageError;
351
352    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
353        match message {
354            message::Message::User { content } => Ok(Self {
355                role: "user".to_string(),
356                content: content.iter().find_map(|c| match c {
357                    message::UserContent::Text(text) => Some(text.text.clone()),
358                    _ => None,
359                }),
360                tool_calls: vec![],
361            }),
362            message::Message::Assistant { content, .. } => {
363                let mut text_content: Option<String> = None;
364                let mut tool_calls = vec![];
365
366                for c in content.iter() {
367                    match c {
368                        message::AssistantContent::Text(text) => {
369                            text_content = Some(
370                                text_content
371                                    .map(|mut existing| {
372                                        existing.push('\n');
373                                        existing.push_str(&text.text);
374                                        existing
375                                    })
376                                    .unwrap_or_else(|| text.text.clone()),
377                            );
378                        }
379                        message::AssistantContent::ToolCall(tool_call) => {
380                            tool_calls.push(tool_call.clone().into());
381                        }
382                        message::AssistantContent::Reasoning(_) => {
383                            return Err(MessageError::ConversionError(
384                                "Galadriel currently doesn't support reasoning.".into(),
385                            ));
386                        }
387                        message::AssistantContent::Image(_) => {
388                            return Err(MessageError::ConversionError(
389                                "Galadriel currently doesn't support images.".into(),
390                            ));
391                        }
392                    }
393                }
394
395                Ok(Self {
396                    role: "assistant".to_string(),
397                    content: text_content,
398                    tool_calls,
399                })
400            }
401        }
402    }
403}
404
405#[derive(Clone, Debug, Deserialize, Serialize)]
406pub struct ToolDefinition {
407    pub r#type: String,
408    pub function: completion::ToolDefinition,
409}
410
411impl From<completion::ToolDefinition> for ToolDefinition {
412    fn from(tool: completion::ToolDefinition) -> Self {
413        Self {
414            r#type: "function".into(),
415            function: tool,
416        }
417    }
418}
419
/// Name and arguments of a function call.
///
/// NOTE(review): not referenced anywhere else in this file — tool calls are parsed
/// through `openai::ToolCall` instead. Kept because it is public; confirm no
/// external users before removing.
#[derive(Debug, Deserialize)]
pub struct Function {
    /// Function name.
    pub name: String,
    /// Arguments as a raw string (presumably JSON-encoded — TODO confirm).
    pub arguments: String,
}
425
/// JSON body sent to Galadriel's `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GaladrielCompletionRequest {
    /// Model identifier, e.g. one of the model constants in this module.
    model: String,
    /// Conversation in send order: optional system preamble, documents, chat history.
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Omitted from the JSON body when no tools are supplied.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    /// Extra parameters flattened into the top level of the JSON body; streaming
    /// requests also use it to inject the `stream` / `stream_options` flags.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
439
440impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest {
441    type Error = CompletionError;
442
443    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
444        // Build up the order of messages (context, chat_history, prompt)
445        let mut partial_history = vec![];
446        if let Some(docs) = req.normalized_documents() {
447            partial_history.push(docs);
448        }
449        partial_history.extend(req.chat_history);
450
451        // Add preamble to chat history (if available)
452        let mut full_history: Vec<Message> = match &req.preamble {
453            Some(preamble) => vec![Message::system(preamble)],
454            None => vec![],
455        };
456
457        // Convert and extend the rest of the history
458        full_history.extend(
459            partial_history
460                .into_iter()
461                .map(message::Message::try_into)
462                .collect::<Result<Vec<Message>, _>>()?,
463        );
464
465        let tool_choice = req
466            .tool_choice
467            .clone()
468            .map(crate::providers::openai::completion::ToolChoice::try_from)
469            .transpose()?;
470
471        Ok(Self {
472            model: model.to_string(),
473            messages: full_history,
474            temperature: req.temperature,
475            tools: req
476                .tools
477                .clone()
478                .into_iter()
479                .map(ToolDefinition::from)
480                .collect::<Vec<_>>(),
481            tool_choice,
482            additional_params: req.additional_params,
483        })
484    }
485}
486
487#[derive(Clone)]
488pub struct CompletionModel<T = reqwest::Client> {
489    client: Client<T>,
490    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
491    pub model: String,
492}
493
494impl<T> CompletionModel<T>
495where
496    T: HttpClientExt,
497{
498    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
499        Self {
500            client,
501            model: model.into(),
502        }
503    }
504
505    pub fn with_model(client: Client<T>, model: &str) -> Self {
506        Self {
507            client,
508            model: model.into(),
509        }
510    }
511}
512
513impl<T> completion::CompletionModel for CompletionModel<T>
514where
515    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
516{
517    type Response = CompletionResponse;
518    type StreamingResponse = openai::StreamingCompletionResponse;
519
520    type Client = Client<T>;
521
522    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
523        Self::new(client.clone(), model.into())
524    }
525
526    async fn completion(
527        &self,
528        completion_request: CompletionRequest,
529    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
530        let span = if tracing::Span::current().is_disabled() {
531            info_span!(
532                target: "rig::completions",
533                "chat",
534                gen_ai.operation.name = "chat",
535                gen_ai.provider.name = "galadriel",
536                gen_ai.request.model = self.model,
537                gen_ai.system_instructions = tracing::field::Empty,
538                gen_ai.response.id = tracing::field::Empty,
539                gen_ai.response.model = tracing::field::Empty,
540                gen_ai.usage.output_tokens = tracing::field::Empty,
541                gen_ai.usage.input_tokens = tracing::field::Empty,
542            )
543        } else {
544            tracing::Span::current()
545        };
546
547        span.record("gen_ai.system_instructions", &completion_request.preamble);
548
549        let request =
550            GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
551
552        if enabled!(tracing::Level::TRACE) {
553            tracing::trace!(target: "rig::completions",
554                "Galadriel completion request: {}",
555                serde_json::to_string_pretty(&request)?
556            );
557        }
558
559        let body = serde_json::to_vec(&request)?;
560
561        let req = self
562            .client
563            .post("/chat/completions")?
564            .body(body)
565            .map_err(http_client::Error::from)?;
566
567        async move {
568            let response = self.client.send(req).await?;
569
570            if response.status().is_success() {
571                let t = http_client::text(response).await?;
572
573                if enabled!(tracing::Level::TRACE) {
574                    tracing::trace!(target: "rig::completions",
575                        "Galadriel completion response: {}",
576                        serde_json::to_string_pretty(&t)?
577                    );
578                }
579
580                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
581                    ApiResponse::Ok(response) => {
582                        let span = tracing::Span::current();
583                        span.record("gen_ai.response.id", response.id.clone());
584                        span.record("gen_ai.response.model_name", response.model.clone());
585                        if let Some(ref usage) = response.usage {
586                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
587                            span.record(
588                                "gen_ai.usage.output_tokens",
589                                usage.total_tokens - usage.prompt_tokens,
590                            );
591                        }
592                        response.try_into()
593                    }
594                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
595                }
596            } else {
597                let text = http_client::text(response).await?;
598
599                Err(CompletionError::ProviderError(text))
600            }
601        }
602        .instrument(span)
603        .await
604    }
605
606    async fn stream(
607        &self,
608        completion_request: CompletionRequest,
609    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
610        let preamble = completion_request.preamble.clone();
611        let mut request =
612            GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
613
614        let params = json_utils::merge(
615            request.additional_params.unwrap_or(serde_json::json!({})),
616            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
617        );
618
619        request.additional_params = Some(params);
620
621        let body = serde_json::to_vec(&request)?;
622
623        let req = self
624            .client
625            .post("/chat/completions")?
626            .body(body)
627            .map_err(http_client::Error::from)?;
628
629        let span = if tracing::Span::current().is_disabled() {
630            info_span!(
631                target: "rig::completions",
632                "chat_streaming",
633                gen_ai.operation.name = "chat_streaming",
634                gen_ai.provider.name = "galadriel",
635                gen_ai.request.model = self.model,
636                gen_ai.system_instructions = preamble,
637                gen_ai.response.id = tracing::field::Empty,
638                gen_ai.response.model = tracing::field::Empty,
639                gen_ai.usage.output_tokens = tracing::field::Empty,
640                gen_ai.usage.input_tokens = tracing::field::Empty,
641                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
642                gen_ai.output.messages = tracing::field::Empty,
643            )
644        } else {
645            tracing::Span::current()
646        };
647
648        send_compatible_streaming_request(self.client.clone(), req)
649            .instrument(span)
650            .await
651    }
652}