Skip to main content

rig/providers/
galadriel.rs

1//! Galadriel API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::galadriel;
6//!
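//! let client = galadriel::Client::new("YOUR_API_KEY").expect("Failed to build client");
//! // to use a fine-tuned model
//! // let client = galadriel::Client::builder()
//! //     .api_key("YOUR_API_KEY")
//! //     .fine_tune_api_key("FINE_TUNE_API_KEY")
//! //     .build()
//! //     .expect("Failed to build client");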
10//!
11//! let gpt4o = client.completion_model(galadriel::GPT_4O);
12//! ```
13use super::openai;
14use crate::client::{
15    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
16    ProviderClient,
17};
18use crate::http_client::{self, HttpClientExt};
19use crate::message::MessageError;
20use crate::providers::openai::send_compatible_streaming_request;
21use crate::streaming::StreamingCompletionResponse;
22use crate::{
23    OneOrMany,
24    completion::{self, CompletionError, CompletionRequest},
25    json_utils, message,
26};
27use serde::{Deserialize, Serialize};
28use tracing::{Instrument, enabled, info_span};
29
30// ================================================================
31// Main Galadriel Client
32// ================================================================
33const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
34
/// Provider-specific state carried by a Galadriel client.
///
/// Produced from a [`GaladrielBuilder`] by `Provider::build`, which copies the
/// builder's fine-tune key into this struct.
#[derive(Debug, Default, Clone)]
pub struct GaladrielExt {
    /// Optional fine-tune API key; `None` when none was configured.
    fine_tune_api_key: Option<String>,
}
39
/// Builder-side extension state for a Galadriel client.
///
/// Defaults to having no fine-tune API key; `ClientBuilder::fine_tune_api_key`
/// replaces it with one that has the key set.
#[derive(Debug, Default, Clone)]
pub struct GaladrielBuilder {
    /// Optional fine-tune API key to hand over to the built client.
    fine_tune_api_key: Option<String>,
}
44
/// API key type used by the Galadriel client builder; an alias for [`BearerAuth`].
type GaladrielApiKey = BearerAuth;
46
47impl Provider for GaladrielExt {
48    type Builder = GaladrielBuilder;
49
50    /// There is currently no way to verify a Galadriel api key without consuming tokens
51    const VERIFY_PATH: &'static str = "";
52
53    fn build<H>(
54        builder: &crate::client::ClientBuilder<
55            Self::Builder,
56            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
57            H,
58        >,
59    ) -> http_client::Result<Self> {
60        let GaladrielBuilder { fine_tune_api_key } = builder.ext().clone();
61
62        Ok(Self { fine_tune_api_key })
63    }
64}
65
/// Capability table for Galadriel: only completions are marked `Capable`;
/// every other capability is declared as `Nothing`.
impl<H> Capabilities<H> for GaladrielExt {
    // The only capability backed by a model type in this file.
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    // The following two entries only exist when the matching cargo feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
76
77impl DebugExt for GaladrielExt {
78    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
79        std::iter::once((
80            "fine_tune_api_key",
81            (&self.fine_tune_api_key as &dyn std::fmt::Debug),
82        ))
83    }
84}
85
/// Wires the Galadriel builder into the generic client builder machinery.
impl ProviderBuilder for GaladrielBuilder {
    // Building this extension yields the provider state defined in this file.
    type Output = GaladrielExt;
    // Alias for `BearerAuth`.
    type ApiKey = GaladrielApiKey;

    // Base URL constant declared at the top of this module.
    const BASE_URL: &'static str = GALADRIEL_API_BASE_URL;
}
92
/// Galadriel client; the HTTP backend defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<GaladrielExt, H>;
/// Builder for [`Client`]; the HTTP backend defaults to `reqwest::Client`.
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<GaladrielBuilder, GaladrielApiKey, H>;
96
97impl<T> ClientBuilder<T> {
98    pub fn fine_tune_api_key<S>(mut self, fine_tune_api_key: S) -> Self
99    where
100        S: AsRef<str>,
101    {
102        *self.ext_mut() = GaladrielBuilder {
103            fine_tune_api_key: Some(fine_tune_api_key.as_ref().into()),
104        };
105
106        self
107    }
108}
109
110impl ProviderClient for Client {
111    type Input = (String, Option<String>);
112
113    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
114    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
115    /// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
116    fn from_env() -> Self {
117        let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
118        let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
119
120        let mut builder = Self::builder().api_key(api_key);
121
122        if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
123            builder = builder.fine_tune_api_key(fine_tune_api_key);
124        }
125
126        builder.build().unwrap()
127    }
128
129    fn from_val((api_key, fine_tune_api_key): Self::Input) -> Self {
130        let mut builder = Self::builder().api_key(api_key);
131
132        if let Some(fine_tune_key) = fine_tune_api_key {
133            builder = builder.fine_tune_api_key(fine_tune_key)
134        }
135
136        builder.build().unwrap()
137    }
138}
139
/// Error payload returned by the API; only the `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text, surfaced to callers as `CompletionError::ProviderError`.
    message: String,
}
144
/// Response body that is either a successful payload `T` or an API error.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error object with a `message`.
    Err(ApiErrorResponse),
}
151
/// Token usage reported with a completion response.
///
/// No completion-token count is stored here; code in this file derives the
/// output token count as `total_tokens - prompt_tokens`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total number of tokens reported for the request.
    pub total_tokens: usize,
}
157
158impl std::fmt::Display for Usage {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        write!(
161            f,
162            "Prompt tokens: {} Total tokens: {}",
163            self.prompt_tokens, self.total_tokens
164        )
165    }
166}
167
168// ================================================================
169// Galadriel Completion API
170// ================================================================
171
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
214
/// Raw chat-completion response body as deserialized from the Galadriel API.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Response identifier; recorded on the tracing span as `gen_ai.response.id`.
    pub id: String,
    pub object: String,
    pub created: u64,
    /// Model name reported by the API.
    pub model: String,
    pub system_fingerprint: Option<String>,
    /// Returned choices; only the first one is used when converting to the
    /// generic completion response.
    pub choices: Vec<Choice>,
    /// Token usage, when the API reports it.
    pub usage: Option<Usage>,
}
225
226impl From<ApiErrorResponse> for CompletionError {
227    fn from(err: ApiErrorResponse) -> Self {
228        CompletionError::ProviderError(err.message)
229    }
230}
231
232impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
233    type Error = CompletionError;
234
235    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
236        let Choice { message, .. } = response.choices.first().ok_or_else(|| {
237            CompletionError::ResponseError("Response contained no choices".to_owned())
238        })?;
239
240        let mut content = message
241            .content
242            .as_ref()
243            .map(|c| vec![completion::AssistantContent::text(c)])
244            .unwrap_or_default();
245
246        content.extend(message.tool_calls.iter().map(|call| {
247            completion::AssistantContent::tool_call(
248                &call.function.name,
249                &call.function.name,
250                call.function.arguments.clone(),
251            )
252        }));
253
254        let choice = OneOrMany::many(content).map_err(|_| {
255            CompletionError::ResponseError(
256                "Response contained no message or tool call (empty)".to_owned(),
257            )
258        })?;
259        let usage = response
260            .usage
261            .as_ref()
262            .map(|usage| completion::Usage {
263                input_tokens: usage.prompt_tokens as u64,
264                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
265                total_tokens: usage.total_tokens as u64,
266                cached_input_tokens: 0,
267            })
268            .unwrap_or_default();
269
270        Ok(completion::CompletionResponse {
271            choice,
272            usage,
273            raw_response: response,
274            message_id: None,
275        })
276    }
277}
278
/// One entry of the `choices` array in a completion response.
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    /// The message produced for this choice.
    pub message: Message,
    /// Kept as untyped JSON; nothing in this file inspects it.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
286
/// Chat message in the wire format used for Galadriel requests and responses.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Role string; this file produces and understands `"system"`, `"user"` and `"assistant"`.
    pub role: String,
    /// Plain text content, if any.
    pub content: Option<String>,
    /// Tool calls attached to the message; defaults to empty when the field is absent.
    // NOTE(review): `json_utils::null_or_vec` presumably maps a JSON `null` to an
    // empty list — confirm against the helper.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<openai::ToolCall>,
}
294
295impl Message {
296    fn system(preamble: &str) -> Self {
297        Self {
298            role: "system".to_string(),
299            content: Some(preamble.to_string()),
300            tool_calls: Vec::new(),
301        }
302    }
303}
304
305impl TryFrom<Message> for message::Message {
306    type Error = message::MessageError;
307
308    fn try_from(message: Message) -> Result<Self, Self::Error> {
309        let tool_calls: Vec<message::ToolCall> = message
310            .tool_calls
311            .into_iter()
312            .map(|tool_call| tool_call.into())
313            .collect();
314
315        match message.role.as_str() {
316            "user" => Ok(Self::User {
317                content: OneOrMany::one(
318                    message
319                        .content
320                        .map(|content| message::UserContent::text(&content))
321                        .ok_or_else(|| {
322                            message::MessageError::ConversionError("Empty user message".to_string())
323                        })?,
324                ),
325            }),
326            "assistant" => Ok(Self::Assistant {
327                id: None,
328                content: OneOrMany::many(
329                    tool_calls
330                        .into_iter()
331                        .map(message::AssistantContent::ToolCall)
332                        .chain(
333                            message
334                                .content
335                                .map(|content| message::AssistantContent::text(&content))
336                                .into_iter(),
337                        ),
338                )
339                .map_err(|_| {
340                    message::MessageError::ConversionError("Empty assistant message".to_string())
341                })?,
342            }),
343            _ => Err(message::MessageError::ConversionError(format!(
344                "Unknown role: {}",
345                message.role
346            ))),
347        }
348    }
349}
350
351impl TryFrom<message::Message> for Message {
352    type Error = message::MessageError;
353
354    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
355        match message {
356            message::Message::User { content } => Ok(Self {
357                role: "user".to_string(),
358                content: content.iter().find_map(|c| match c {
359                    message::UserContent::Text(text) => Some(text.text.clone()),
360                    _ => None,
361                }),
362                tool_calls: vec![],
363            }),
364            message::Message::Assistant { content, .. } => {
365                let mut text_content: Option<String> = None;
366                let mut tool_calls = vec![];
367
368                for c in content.iter() {
369                    match c {
370                        message::AssistantContent::Text(text) => {
371                            text_content = Some(
372                                text_content
373                                    .map(|mut existing| {
374                                        existing.push('\n');
375                                        existing.push_str(&text.text);
376                                        existing
377                                    })
378                                    .unwrap_or_else(|| text.text.clone()),
379                            );
380                        }
381                        message::AssistantContent::ToolCall(tool_call) => {
382                            tool_calls.push(tool_call.clone().into());
383                        }
384                        message::AssistantContent::Reasoning(_) => {
385                            return Err(MessageError::ConversionError(
386                                "Galadriel currently doesn't support reasoning.".into(),
387                            ));
388                        }
389                        message::AssistantContent::Image(_) => {
390                            return Err(MessageError::ConversionError(
391                                "Galadriel currently doesn't support images.".into(),
392                            ));
393                        }
394                    }
395                }
396
397                Ok(Self {
398                    role: "assistant".to_string(),
399                    content: text_content,
400                    tool_calls,
401                })
402            }
403        }
404    }
405}
406
/// Tool entry as sent in the request's `tools` array: a `type` tag wrapping
/// the generic tool definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind; the `From` conversion in this file always sets it to `"function"`.
    pub r#type: String,
    /// The generic tool definition being wrapped.
    pub function: completion::ToolDefinition,
}
412
413impl From<completion::ToolDefinition> for ToolDefinition {
414    fn from(tool: completion::ToolDefinition) -> Self {
415        Self {
416            r#type: "function".into(),
417            function: tool,
418        }
419    }
420}
421
/// Function name and arguments pair; deserialize-only.
// NOTE(review): nothing in the visible part of this file references this
// type — confirm whether it is still used elsewhere.
#[derive(Debug, Deserialize)]
pub struct Function {
    /// Function name.
    pub name: String,
    /// Function arguments, kept as a string.
    pub arguments: String,
}
427
/// Request body serialized and posted to `/chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GaladrielCompletionRequest {
    /// Model name to request.
    model: String,
    /// Messages in send order: optional system preamble first, then the history.
    pub messages: Vec<Message>,
    /// Temperature copied from the completion request; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tool definitions; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Tool choice in the OpenAI provider's format; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    /// Extra parameters flattened into the top level of the request JSON
    /// (the streaming path adds `stream` and `stream_options` here).
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
441
442impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest {
443    type Error = CompletionError;
444
445    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
446        if req.output_schema.is_some() {
447            tracing::warn!("Structured outputs currently not supported for Galadriel");
448        }
449        let model = req.model.clone().unwrap_or_else(|| model.to_string());
450        // Build up the order of messages (context, chat_history, prompt)
451        let mut partial_history = vec![];
452        if let Some(docs) = req.normalized_documents() {
453            partial_history.push(docs);
454        }
455        partial_history.extend(req.chat_history);
456
457        // Add preamble to chat history (if available)
458        let mut full_history: Vec<Message> = match &req.preamble {
459            Some(preamble) => vec![Message::system(preamble)],
460            None => vec![],
461        };
462
463        // Convert and extend the rest of the history
464        full_history.extend(
465            partial_history
466                .into_iter()
467                .map(message::Message::try_into)
468                .collect::<Result<Vec<Message>, _>>()?,
469        );
470
471        let tool_choice = req
472            .tool_choice
473            .clone()
474            .map(crate::providers::openai::completion::ToolChoice::try_from)
475            .transpose()?;
476
477        Ok(Self {
478            model: model.to_string(),
479            messages: full_history,
480            temperature: req.temperature,
481            tools: req
482                .tools
483                .clone()
484                .into_iter()
485                .map(ToolDefinition::from)
486                .collect::<Vec<_>>(),
487            tool_choice,
488            additional_params: req.additional_params,
489        })
490    }
491}
492
/// Galadriel completion model: a client paired with the model name to request.
///
/// The HTTP backend type parameter defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    client: Client<T>,
    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
    pub model: String,
}
499
500impl<T> CompletionModel<T>
501where
502    T: HttpClientExt,
503{
504    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
505        Self {
506            client,
507            model: model.into(),
508        }
509    }
510
511    pub fn with_model(client: Client<T>, model: &str) -> Self {
512        Self {
513            client,
514            model: model.into(),
515        }
516    }
517}
518
519impl<T> completion::CompletionModel for CompletionModel<T>
520where
521    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
522{
523    type Response = CompletionResponse;
524    type StreamingResponse = openai::StreamingCompletionResponse;
525
526    type Client = Client<T>;
527
528    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
529        Self::new(client.clone(), model.into())
530    }
531
532    async fn completion(
533        &self,
534        completion_request: CompletionRequest,
535    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
536        let span = if tracing::Span::current().is_disabled() {
537            info_span!(
538                target: "rig::completions",
539                "chat",
540                gen_ai.operation.name = "chat",
541                gen_ai.provider.name = "galadriel",
542                gen_ai.request.model = self.model,
543                gen_ai.system_instructions = tracing::field::Empty,
544                gen_ai.response.id = tracing::field::Empty,
545                gen_ai.response.model = tracing::field::Empty,
546                gen_ai.usage.output_tokens = tracing::field::Empty,
547                gen_ai.usage.input_tokens = tracing::field::Empty,
548            )
549        } else {
550            tracing::Span::current()
551        };
552
553        span.record("gen_ai.system_instructions", &completion_request.preamble);
554
555        let request =
556            GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
557
558        if enabled!(tracing::Level::TRACE) {
559            tracing::trace!(target: "rig::completions",
560                "Galadriel completion request: {}",
561                serde_json::to_string_pretty(&request)?
562            );
563        }
564
565        let body = serde_json::to_vec(&request)?;
566
567        let req = self
568            .client
569            .post("/chat/completions")?
570            .body(body)
571            .map_err(http_client::Error::from)?;
572
573        async move {
574            let response = self.client.send(req).await?;
575
576            if response.status().is_success() {
577                let t = http_client::text(response).await?;
578
579                if enabled!(tracing::Level::TRACE) {
580                    tracing::trace!(target: "rig::completions",
581                        "Galadriel completion response: {}",
582                        serde_json::to_string_pretty(&t)?
583                    );
584                }
585
586                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
587                    ApiResponse::Ok(response) => {
588                        let span = tracing::Span::current();
589                        span.record("gen_ai.response.id", response.id.clone());
590                        span.record("gen_ai.response.model_name", response.model.clone());
591                        if let Some(ref usage) = response.usage {
592                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
593                            span.record(
594                                "gen_ai.usage.output_tokens",
595                                usage.total_tokens - usage.prompt_tokens,
596                            );
597                        }
598                        response.try_into()
599                    }
600                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
601                }
602            } else {
603                let text = http_client::text(response).await?;
604
605                Err(CompletionError::ProviderError(text))
606            }
607        }
608        .instrument(span)
609        .await
610    }
611
612    async fn stream(
613        &self,
614        completion_request: CompletionRequest,
615    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
616        let preamble = completion_request.preamble.clone();
617        let mut request =
618            GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
619
620        let params = json_utils::merge(
621            request.additional_params.unwrap_or(serde_json::json!({})),
622            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
623        );
624
625        request.additional_params = Some(params);
626
627        let body = serde_json::to_vec(&request)?;
628
629        let req = self
630            .client
631            .post("/chat/completions")?
632            .body(body)
633            .map_err(http_client::Error::from)?;
634
635        let span = if tracing::Span::current().is_disabled() {
636            info_span!(
637                target: "rig::completions",
638                "chat_streaming",
639                gen_ai.operation.name = "chat_streaming",
640                gen_ai.provider.name = "galadriel",
641                gen_ai.request.model = self.model,
642                gen_ai.system_instructions = preamble,
643                gen_ai.response.id = tracing::field::Empty,
644                gen_ai.response.model = tracing::field::Empty,
645                gen_ai.usage.output_tokens = tracing::field::Empty,
646                gen_ai.usage.input_tokens = tracing::field::Empty,
647                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
648                gen_ai.output.messages = tracing::field::Empty,
649            )
650        } else {
651            tracing::Span::current()
652        };
653
654        send_compatible_streaming_request(self.client.clone(), req)
655            .instrument(span)
656            .await
657    }
658}
#[cfg(test)]
mod tests {
    use crate::providers::galadriel::Client;

    /// Both construction paths must succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _client: Client = Client::new("dummy-key").expect("Client::new() failed");

        let builder = Client::builder().api_key("dummy-key");
        let _client_from_builder: Client = builder.build().expect("Client::builder() failed");
    }
}
671}