1use super::openai;
20use crate::client::{
21 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22 ProviderClient,
23};
24use crate::http_client::{self, HttpClientExt};
25use crate::message::MessageError;
26use crate::providers::openai::send_compatible_streaming_request;
27use crate::streaming::StreamingCompletionResponse;
28use crate::{
29 OneOrMany,
30 completion::{self, CompletionError, CompletionRequest},
31 json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34use tracing::{Instrument, enabled, info_span};
35
/// Default base URL for the Galadriel API.
///
/// Used as [`ProviderBuilder::BASE_URL`] for [`GaladrielBuilder`].
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
40
/// Provider extension stored on a built Galadriel [`Client`].
///
/// Holds the optional fine-tune API key copied from [`GaladrielBuilder`].
#[derive(Default, Clone)]
pub struct GaladrielExt {
    /// Optional fine-tune API key; `None` unless one was supplied to the builder.
    fine_tune_api_key: Option<String>,
}

/// Builder-side extension state for a Galadriel [`ClientBuilder`].
#[derive(Default, Clone)]
pub struct GaladrielBuilder {
    /// Fine-tune API key set via [`ClientBuilder::fine_tune_api_key`], if any.
    fine_tune_api_key: Option<String>,
}

/// Debug stand-in for a secret: `Some("<redacted>")` when set, `None` otherwise.
fn redacted_secret(secret: Option<&str>) -> Option<&'static str> {
    secret.map(|_| "<redacted>")
}

// The fine-tune key is a credential. `Debug` output (and therefore any log that
// formats these values) only reveals whether the key is set, never its contents.
impl std::fmt::Debug for GaladrielExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GaladrielExt")
            .field(
                "fine_tune_api_key",
                &redacted_secret(self.fine_tune_api_key.as_deref()),
            )
            .finish()
    }
}

impl std::fmt::Debug for GaladrielBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GaladrielBuilder")
            .field(
                "fine_tune_api_key",
                &redacted_secret(self.fine_tune_api_key.as_deref()),
            )
            .finish()
    }
}
50
51type GaladrielApiKey = BearerAuth;
52
53impl Provider for GaladrielExt {
54 type Builder = GaladrielBuilder;
55
56 const VERIFY_PATH: &'static str = "";
58}
59
60impl<H> Capabilities<H> for GaladrielExt {
61 type Completion = Capable<CompletionModel<H>>;
62 type Embeddings = Nothing;
63 type Transcription = Nothing;
64 type ModelListing = Nothing;
65 #[cfg(feature = "image")]
66 type ImageGeneration = Nothing;
67 #[cfg(feature = "audio")]
68 type AudioGeneration = Nothing;
69 type Rerank = Nothing;
70}
71
72impl DebugExt for GaladrielExt {
73 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
74 std::iter::once((
75 "fine_tune_api_key",
76 (&self.fine_tune_api_key as &dyn std::fmt::Debug),
77 ))
78 }
79}
80
81impl ProviderBuilder for GaladrielBuilder {
82 type Extension<H>
83 = GaladrielExt
84 where
85 H: HttpClientExt;
86 type ApiKey = GaladrielApiKey;
87
88 const BASE_URL: &'static str = GALADRIEL_API_BASE_URL;
89
90 fn build<H>(
91 builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
92 ) -> http_client::Result<Self::Extension<H>>
93 where
94 H: HttpClientExt,
95 {
96 let GaladrielBuilder { fine_tune_api_key } = builder.ext().clone();
97
98 Ok(GaladrielExt { fine_tune_api_key })
99 }
100}
101
102pub type Client<H = reqwest::Client> = client::Client<GaladrielExt, H>;
103pub type ClientBuilder<H = crate::markers::Missing> =
104 client::ClientBuilder<GaladrielBuilder, GaladrielApiKey, H>;
105
106impl<T> ClientBuilder<T> {
107 pub fn fine_tune_api_key<S>(mut self, fine_tune_api_key: S) -> Self
108 where
109 S: AsRef<str>,
110 {
111 *self.ext_mut() = GaladrielBuilder {
112 fine_tune_api_key: Some(fine_tune_api_key.as_ref().into()),
113 };
114
115 self
116 }
117}
118
119impl ProviderClient for Client {
120 type Input = (String, Option<String>);
121 type Error = crate::client::ProviderClientError;
122
123 fn from_env() -> Result<Self, Self::Error> {
126 let api_key = crate::client::required_env_var("GALADRIEL_API_KEY")?;
127 let fine_tune_api_key = crate::client::optional_env_var("GALADRIEL_FINE_TUNE_API_KEY")?;
128
129 let mut builder = Self::builder().api_key(api_key);
130
131 if let Some(fine_tune_api_key) = fine_tune_api_key {
132 builder = builder.fine_tune_api_key(fine_tune_api_key);
133 }
134
135 builder.build().map_err(Into::into)
136 }
137
138 fn from_val((api_key, fine_tune_api_key): Self::Input) -> Result<Self, Self::Error> {
139 let mut builder = Self::builder().api_key(api_key);
140
141 if let Some(fine_tune_key) = fine_tune_api_key {
142 builder = builder.fine_tune_api_key(fine_tune_key)
143 }
144
145 builder.build().map_err(Into::into)
146 }
147}
148
149#[derive(Debug, Deserialize)]
150struct ApiErrorResponse {
151 message: String,
152}
153
154#[derive(Debug, Deserialize)]
155#[serde(untagged)]
156enum ApiResponse<T> {
157 Ok(T),
158 Err(ApiErrorResponse),
159}
160
161#[derive(Clone, Debug, Deserialize, Serialize)]
162pub struct Usage {
163 pub prompt_tokens: usize,
164 pub total_tokens: usize,
165}
166
167impl std::fmt::Display for Usage {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 write!(
170 f,
171 "Prompt tokens: {} Total tokens: {}",
172 self.prompt_tokens, self.total_tokens
173 )
174 }
175}
176
/// Model name `o1-preview`.
pub const O1_PREVIEW: &str = "o1-preview";
/// Model name `o1-preview-2024-09-12`.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// Model name `o1-mini`.
pub const O1_MINI: &str = "o1-mini";
/// Model name `o1-mini-2024-09-12`.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// Model name `gpt-4o`.
pub const GPT_4O: &str = "gpt-4o";
/// Model name `gpt-4o-2024-05-13`.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// Model name `gpt-4-turbo`.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// Model name `gpt-4-turbo-2024-04-09`.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// Model name `gpt-4-turbo-preview`.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// Model name `gpt-4-0125-preview`.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// Model name `gpt-4-1106-preview`.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// Model name `gpt-4-vision-preview`.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// Model name `gpt-4-1106-vision-preview`.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// Model name `gpt-4`.
pub const GPT_4: &str = "gpt-4";
/// Model name `gpt-4-0613`.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// Model name `gpt-4-32k`.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// Model name `gpt-4-32k-0613`.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// Model name `gpt-3.5-turbo`.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// Model name `gpt-3.5-turbo-0125`.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// Model name `gpt-3.5-turbo-1106`.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// Model name `gpt-3.5-turbo-instruct`.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
223
224#[derive(Debug, Deserialize, Serialize)]
225pub struct CompletionResponse {
226 pub id: String,
227 pub object: String,
228 pub created: u64,
229 pub model: String,
230 pub system_fingerprint: Option<String>,
231 pub choices: Vec<Choice>,
232 pub usage: Option<Usage>,
233}
234
235impl From<ApiErrorResponse> for CompletionError {
236 fn from(err: ApiErrorResponse) -> Self {
237 CompletionError::ProviderError(err.message)
238 }
239}
240
241impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
242 type Error = CompletionError;
243
244 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
245 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
246 CompletionError::ResponseError("Response contained no choices".to_owned())
247 })?;
248
249 let mut content = message
250 .content
251 .as_ref()
252 .map(|c| vec![completion::AssistantContent::text(c)])
253 .unwrap_or_default();
254
255 content.extend(message.tool_calls.iter().map(|call| {
256 completion::AssistantContent::tool_call(
257 &call.function.name,
258 &call.function.name,
259 call.function.arguments.clone(),
260 )
261 }));
262
263 let choice = OneOrMany::many(content).map_err(|_| {
264 CompletionError::ResponseError(
265 "Response contained no message or tool call (empty)".to_owned(),
266 )
267 })?;
268 let usage = response
269 .usage
270 .as_ref()
271 .map(|usage| completion::Usage {
272 input_tokens: usage.prompt_tokens as u64,
273 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
274 total_tokens: usage.total_tokens as u64,
275 cached_input_tokens: 0,
276 cache_creation_input_tokens: 0,
277 tool_use_prompt_tokens: 0,
278 reasoning_tokens: 0,
279 })
280 .unwrap_or_default();
281
282 Ok(completion::CompletionResponse {
283 choice,
284 usage,
285 raw_response: response,
286 message_id: None,
287 })
288 }
289}
290
291#[derive(Debug, Deserialize, Serialize)]
292pub struct Choice {
293 pub index: usize,
294 pub message: Message,
295 pub logprobs: Option<serde_json::Value>,
296 pub finish_reason: String,
297}
298
299#[derive(Debug, Serialize, Deserialize)]
300pub struct Message {
301 pub role: String,
302 pub content: Option<String>,
303 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
304 pub tool_calls: Vec<openai::ToolCall>,
305}
306
307impl Message {
308 fn system(preamble: &str) -> Self {
309 Self {
310 role: "system".to_string(),
311 content: Some(preamble.to_string()),
312 tool_calls: Vec::new(),
313 }
314 }
315}
316
317impl TryFrom<Message> for message::Message {
318 type Error = message::MessageError;
319
320 fn try_from(message: Message) -> Result<Self, Self::Error> {
321 let tool_calls: Vec<message::ToolCall> = message
322 .tool_calls
323 .into_iter()
324 .map(|tool_call| tool_call.into())
325 .collect();
326
327 match message.role.as_str() {
328 "user" => Ok(Self::User {
329 content: OneOrMany::one(
330 message
331 .content
332 .map(|content| message::UserContent::text(&content))
333 .ok_or_else(|| {
334 message::MessageError::ConversionError("Empty user message".to_string())
335 })?,
336 ),
337 }),
338 "assistant" => Ok(Self::Assistant {
339 id: None,
340 content: OneOrMany::many(
341 tool_calls
342 .into_iter()
343 .map(message::AssistantContent::ToolCall)
344 .chain(
345 message
346 .content
347 .map(|content| message::AssistantContent::text(&content))
348 .into_iter(),
349 ),
350 )
351 .map_err(|_| {
352 message::MessageError::ConversionError("Empty assistant message".to_string())
353 })?,
354 }),
355 _ => Err(message::MessageError::ConversionError(format!(
356 "Unknown role: {}",
357 message.role
358 ))),
359 }
360 }
361}
362
363impl TryFrom<message::Message> for Message {
364 type Error = message::MessageError;
365
366 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
367 match message {
368 message::Message::System { content } => Ok(Self {
369 role: "system".to_string(),
370 content: Some(content),
371 tool_calls: vec![],
372 }),
373 message::Message::User { content } => Ok(Self {
374 role: "user".to_string(),
375 content: content.iter().find_map(|c| match c {
376 message::UserContent::Text(text) => Some(text.text.clone()),
377 _ => None,
378 }),
379 tool_calls: vec![],
380 }),
381 message::Message::Assistant { content, .. } => {
382 let mut text_content: Option<String> = None;
383 let mut tool_calls = vec![];
384
385 for c in content.iter() {
386 match c {
387 message::AssistantContent::Text(text) => {
388 text_content = Some(
389 text_content
390 .map(|mut existing| {
391 existing.push('\n');
392 existing.push_str(&text.text);
393 existing
394 })
395 .unwrap_or_else(|| text.text.clone()),
396 );
397 }
398 message::AssistantContent::ToolCall(tool_call) => {
399 tool_calls.push(tool_call.clone().into());
400 }
401 message::AssistantContent::Reasoning(_) => {
402 return Err(MessageError::ConversionError(
403 "Galadriel currently doesn't support reasoning.".into(),
404 ));
405 }
406 message::AssistantContent::Image(_) => {
407 return Err(MessageError::ConversionError(
408 "Galadriel currently doesn't support images.".into(),
409 ));
410 }
411 }
412 }
413
414 Ok(Self {
415 role: "assistant".to_string(),
416 content: text_content,
417 tool_calls,
418 })
419 }
420 }
421 }
422}
423
424#[derive(Clone, Debug, Deserialize, Serialize)]
425pub struct ToolDefinition {
426 pub r#type: String,
427 pub function: completion::ToolDefinition,
428}
429
430impl From<completion::ToolDefinition> for ToolDefinition {
431 fn from(tool: completion::ToolDefinition) -> Self {
432 Self {
433 r#type: "function".into(),
434 function: tool,
435 }
436 }
437}
438
439#[derive(Debug, Deserialize)]
440pub struct Function {
441 pub name: String,
442 pub arguments: String,
443}
444
445#[derive(Debug, Serialize, Deserialize)]
446pub(super) struct GaladrielCompletionRequest {
447 model: String,
448 pub messages: Vec<Message>,
449 #[serde(skip_serializing_if = "Option::is_none")]
450 temperature: Option<f64>,
451 #[serde(skip_serializing_if = "Vec::is_empty")]
452 tools: Vec<ToolDefinition>,
453 #[serde(skip_serializing_if = "Option::is_none")]
454 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
455 #[serde(flatten, skip_serializing_if = "Option::is_none")]
456 pub additional_params: Option<serde_json::Value>,
457}
458
459impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest {
460 type Error = CompletionError;
461
462 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
463 let chat_history = req.chat_history_with_documents();
464 if req.output_schema.is_some() {
465 tracing::warn!("Structured outputs currently not supported for Galadriel");
466 }
467 let model = req.model.clone().unwrap_or_else(|| model.to_string());
468 let mut partial_history = vec![];
470 partial_history.extend(chat_history);
471
472 let mut full_history: Vec<Message> = match &req.preamble {
474 Some(preamble) => vec![Message::system(preamble)],
475 None => vec![],
476 };
477
478 full_history.extend(
480 partial_history
481 .into_iter()
482 .map(message::Message::try_into)
483 .collect::<Result<Vec<Message>, _>>()?,
484 );
485
486 let tool_choice = req
487 .tool_choice
488 .clone()
489 .map(crate::providers::openai::completion::ToolChoice::try_from)
490 .transpose()?;
491
492 Ok(Self {
493 model: model.to_string(),
494 messages: full_history,
495 temperature: req.temperature,
496 tools: req
497 .tools
498 .clone()
499 .into_iter()
500 .map(ToolDefinition::from)
501 .collect::<Vec<_>>(),
502 tool_choice,
503 additional_params: req.additional_params,
504 })
505 }
506}
507
508#[derive(Clone)]
509pub struct CompletionModel<T = reqwest::Client> {
510 client: Client<T>,
511 pub model: String,
513}
514
515impl<T> CompletionModel<T>
516where
517 T: HttpClientExt,
518{
519 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
520 Self {
521 client,
522 model: model.into(),
523 }
524 }
525
526 pub fn with_model(client: Client<T>, model: &str) -> Self {
527 Self {
528 client,
529 model: model.into(),
530 }
531 }
532}
533
534impl<T> completion::CompletionModel for CompletionModel<T>
535where
536 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
537{
538 type Response = CompletionResponse;
539 type StreamingResponse = openai::StreamingCompletionResponse;
540
541 type Client = Client<T>;
542
543 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
544 Self::new(client.clone(), model.into())
545 }
546
547 async fn completion(
548 &self,
549 completion_request: CompletionRequest,
550 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
551 let span = if tracing::Span::current().is_disabled() {
552 info_span!(
553 target: "rig::completions",
554 "chat",
555 gen_ai.operation.name = "chat",
556 gen_ai.provider.name = "galadriel",
557 gen_ai.request.model = self.model,
558 gen_ai.system_instructions = tracing::field::Empty,
559 gen_ai.response.id = tracing::field::Empty,
560 gen_ai.response.model = tracing::field::Empty,
561 gen_ai.usage.output_tokens = tracing::field::Empty,
562 gen_ai.usage.input_tokens = tracing::field::Empty,
563 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
564 )
565 } else {
566 tracing::Span::current()
567 };
568
569 span.record("gen_ai.system_instructions", &completion_request.preamble);
570
571 let request =
572 GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
573
574 if enabled!(tracing::Level::TRACE) {
575 tracing::trace!(target: "rig::completions",
576 "Galadriel completion request: {}",
577 serde_json::to_string_pretty(&request)?
578 );
579 }
580
581 let body = serde_json::to_vec(&request)?;
582
583 let req = self
584 .client
585 .post("/chat/completions")?
586 .body(body)
587 .map_err(http_client::Error::from)?;
588
589 async move {
590 let response = self.client.send(req).await?;
591
592 if response.status().is_success() {
593 let t = http_client::text(response).await?;
594
595 if enabled!(tracing::Level::TRACE) {
596 tracing::trace!(target: "rig::completions",
597 "Galadriel completion response: {}",
598 serde_json::to_string_pretty(&t)?
599 );
600 }
601
602 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
603 ApiResponse::Ok(response) => {
604 let span = tracing::Span::current();
605 span.record("gen_ai.response.id", response.id.clone());
606 span.record("gen_ai.response.model", response.model.clone());
607 if let Some(ref usage) = response.usage {
608 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
609 span.record(
610 "gen_ai.usage.output_tokens",
611 usage.total_tokens - usage.prompt_tokens,
612 );
613 }
614 response.try_into()
615 }
616 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
617 }
618 } else {
619 let text = http_client::text(response).await?;
620
621 Err(CompletionError::ProviderError(text))
622 }
623 }
624 .instrument(span)
625 .await
626 }
627
628 async fn stream(
629 &self,
630 completion_request: CompletionRequest,
631 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
632 let preamble = completion_request.preamble.clone();
633 let mut request =
634 GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
635
636 let params = json_utils::merge(
637 request.additional_params.unwrap_or(serde_json::json!({})),
638 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
639 );
640
641 request.additional_params = Some(params);
642
643 let body = serde_json::to_vec(&request)?;
644
645 let req = self
646 .client
647 .post("/chat/completions")?
648 .body(body)
649 .map_err(http_client::Error::from)?;
650
651 let span = if tracing::Span::current().is_disabled() {
652 info_span!(
653 target: "rig::completions",
654 "chat_streaming",
655 gen_ai.operation.name = "chat_streaming",
656 gen_ai.provider.name = "galadriel",
657 gen_ai.request.model = self.model,
658 gen_ai.system_instructions = preamble,
659 gen_ai.response.id = tracing::field::Empty,
660 gen_ai.response.model = tracing::field::Empty,
661 gen_ai.usage.output_tokens = tracing::field::Empty,
662 gen_ai.usage.input_tokens = tracing::field::Empty,
663 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
664 gen_ai.input.messages = serde_json::to_string(&request.messages)?,
665 gen_ai.output.messages = tracing::field::Empty,
666 )
667 } else {
668 tracing::Span::current()
669 };
670
671 send_compatible_streaming_request(self.client.clone(), req)
672 .instrument(span)
673 .await
674 }
675}
#[cfg(test)]
mod tests {
    use super::Client;

    /// Both construction paths succeed with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}