1use super::openai;
20use crate::client::{
21 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22 ProviderClient,
23};
24use crate::http_client::{self, HttpClientExt};
25use crate::message::MessageError;
26use crate::providers::openai::send_compatible_streaming_request;
27use crate::streaming::StreamingCompletionResponse;
28use crate::{
29 OneOrMany,
30 completion::{self, CompletionError, CompletionRequest},
31 json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34use tracing::{Instrument, enabled, info_span};
35
/// Base URL of the Galadriel API; exposed to the client builder as
/// `ProviderBuilder::BASE_URL`.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
40
/// Galadriel-specific state attached to a built client.
#[derive(Clone, Debug, Default)]
pub struct GaladrielExt {
    /// Optional fine-tune API key; `None` unless one was set on the builder.
    fine_tune_api_key: Option<String>,
}
45
/// Builder-side counterpart of [`GaladrielExt`]; collects the optional
/// fine-tune API key before the client is built.
#[derive(Clone, Debug, Default)]
pub struct GaladrielBuilder {
    /// Optional fine-tune API key; `None` until explicitly provided.
    fine_tune_api_key: Option<String>,
}
50
51type GaladrielApiKey = BearerAuth;
52
53impl Provider for GaladrielExt {
54 type Builder = GaladrielBuilder;
55
56 const VERIFY_PATH: &'static str = "";
58}
59
60impl<H> Capabilities<H> for GaladrielExt {
61 type Completion = Capable<CompletionModel<H>>;
62 type Embeddings = Nothing;
63 type Transcription = Nothing;
64 type ModelListing = Nothing;
65 #[cfg(feature = "image")]
66 type ImageGeneration = Nothing;
67 #[cfg(feature = "audio")]
68 type AudioGeneration = Nothing;
69}
70
71impl DebugExt for GaladrielExt {
72 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
73 std::iter::once((
74 "fine_tune_api_key",
75 (&self.fine_tune_api_key as &dyn std::fmt::Debug),
76 ))
77 }
78}
79
80impl ProviderBuilder for GaladrielBuilder {
81 type Extension<H>
82 = GaladrielExt
83 where
84 H: HttpClientExt;
85 type ApiKey = GaladrielApiKey;
86
87 const BASE_URL: &'static str = GALADRIEL_API_BASE_URL;
88
89 fn build<H>(
90 builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
91 ) -> http_client::Result<Self::Extension<H>>
92 where
93 H: HttpClientExt,
94 {
95 let GaladrielBuilder { fine_tune_api_key } = builder.ext().clone();
96
97 Ok(GaladrielExt { fine_tune_api_key })
98 }
99}
100
101pub type Client<H = reqwest::Client> = client::Client<GaladrielExt, H>;
102pub type ClientBuilder<H = crate::markers::Missing> =
103 client::ClientBuilder<GaladrielBuilder, GaladrielApiKey, H>;
104
105impl<T> ClientBuilder<T> {
106 pub fn fine_tune_api_key<S>(mut self, fine_tune_api_key: S) -> Self
107 where
108 S: AsRef<str>,
109 {
110 *self.ext_mut() = GaladrielBuilder {
111 fine_tune_api_key: Some(fine_tune_api_key.as_ref().into()),
112 };
113
114 self
115 }
116}
117
118impl ProviderClient for Client {
119 type Input = (String, Option<String>);
120 type Error = crate::client::ProviderClientError;
121
122 fn from_env() -> Result<Self, Self::Error> {
125 let api_key = crate::client::required_env_var("GALADRIEL_API_KEY")?;
126 let fine_tune_api_key = crate::client::optional_env_var("GALADRIEL_FINE_TUNE_API_KEY")?;
127
128 let mut builder = Self::builder().api_key(api_key);
129
130 if let Some(fine_tune_api_key) = fine_tune_api_key {
131 builder = builder.fine_tune_api_key(fine_tune_api_key);
132 }
133
134 builder.build().map_err(Into::into)
135 }
136
137 fn from_val((api_key, fine_tune_api_key): Self::Input) -> Result<Self, Self::Error> {
138 let mut builder = Self::builder().api_key(api_key);
139
140 if let Some(fine_tune_key) = fine_tune_api_key {
141 builder = builder.fine_tune_api_key(fine_tune_key)
142 }
143
144 builder.build().map_err(Into::into)
145 }
146}
147
148#[derive(Debug, Deserialize)]
149struct ApiErrorResponse {
150 message: String,
151}
152
153#[derive(Debug, Deserialize)]
154#[serde(untagged)]
155enum ApiResponse<T> {
156 Ok(T),
157 Err(ApiErrorResponse),
158}
159
160#[derive(Clone, Debug, Deserialize, Serialize)]
161pub struct Usage {
162 pub prompt_tokens: usize,
163 pub total_tokens: usize,
164}
165
166impl std::fmt::Display for Usage {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 write!(
169 f,
170 "Prompt tokens: {} Total tokens: {}",
171 self.prompt_tokens, self.total_tokens
172 )
173 }
174}
175
// Model identifiers accepted as the `model` field of a completion request.

/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
222
223#[derive(Debug, Deserialize, Serialize)]
224pub struct CompletionResponse {
225 pub id: String,
226 pub object: String,
227 pub created: u64,
228 pub model: String,
229 pub system_fingerprint: Option<String>,
230 pub choices: Vec<Choice>,
231 pub usage: Option<Usage>,
232}
233
234impl From<ApiErrorResponse> for CompletionError {
235 fn from(err: ApiErrorResponse) -> Self {
236 CompletionError::ProviderError(err.message)
237 }
238}
239
240impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
241 type Error = CompletionError;
242
243 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
244 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
245 CompletionError::ResponseError("Response contained no choices".to_owned())
246 })?;
247
248 let mut content = message
249 .content
250 .as_ref()
251 .map(|c| vec![completion::AssistantContent::text(c)])
252 .unwrap_or_default();
253
254 content.extend(message.tool_calls.iter().map(|call| {
255 completion::AssistantContent::tool_call(
256 &call.function.name,
257 &call.function.name,
258 call.function.arguments.clone(),
259 )
260 }));
261
262 let choice = OneOrMany::many(content).map_err(|_| {
263 CompletionError::ResponseError(
264 "Response contained no message or tool call (empty)".to_owned(),
265 )
266 })?;
267 let usage = response
268 .usage
269 .as_ref()
270 .map(|usage| completion::Usage {
271 input_tokens: usage.prompt_tokens as u64,
272 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
273 total_tokens: usage.total_tokens as u64,
274 cached_input_tokens: 0,
275 cache_creation_input_tokens: 0,
276 tool_use_prompt_tokens: 0,
277 reasoning_tokens: 0,
278 })
279 .unwrap_or_default();
280
281 Ok(completion::CompletionResponse {
282 choice,
283 usage,
284 raw_response: response,
285 message_id: None,
286 })
287 }
288}
289
290#[derive(Debug, Deserialize, Serialize)]
291pub struct Choice {
292 pub index: usize,
293 pub message: Message,
294 pub logprobs: Option<serde_json::Value>,
295 pub finish_reason: String,
296}
297
298#[derive(Debug, Serialize, Deserialize)]
299pub struct Message {
300 pub role: String,
301 pub content: Option<String>,
302 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
303 pub tool_calls: Vec<openai::ToolCall>,
304}
305
306impl Message {
307 fn system(preamble: &str) -> Self {
308 Self {
309 role: "system".to_string(),
310 content: Some(preamble.to_string()),
311 tool_calls: Vec::new(),
312 }
313 }
314}
315
316impl TryFrom<Message> for message::Message {
317 type Error = message::MessageError;
318
319 fn try_from(message: Message) -> Result<Self, Self::Error> {
320 let tool_calls: Vec<message::ToolCall> = message
321 .tool_calls
322 .into_iter()
323 .map(|tool_call| tool_call.into())
324 .collect();
325
326 match message.role.as_str() {
327 "user" => Ok(Self::User {
328 content: OneOrMany::one(
329 message
330 .content
331 .map(|content| message::UserContent::text(&content))
332 .ok_or_else(|| {
333 message::MessageError::ConversionError("Empty user message".to_string())
334 })?,
335 ),
336 }),
337 "assistant" => Ok(Self::Assistant {
338 id: None,
339 content: OneOrMany::many(
340 tool_calls
341 .into_iter()
342 .map(message::AssistantContent::ToolCall)
343 .chain(
344 message
345 .content
346 .map(|content| message::AssistantContent::text(&content))
347 .into_iter(),
348 ),
349 )
350 .map_err(|_| {
351 message::MessageError::ConversionError("Empty assistant message".to_string())
352 })?,
353 }),
354 _ => Err(message::MessageError::ConversionError(format!(
355 "Unknown role: {}",
356 message.role
357 ))),
358 }
359 }
360}
361
362impl TryFrom<message::Message> for Message {
363 type Error = message::MessageError;
364
365 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
366 match message {
367 message::Message::System { content } => Ok(Self {
368 role: "system".to_string(),
369 content: Some(content),
370 tool_calls: vec![],
371 }),
372 message::Message::User { content } => Ok(Self {
373 role: "user".to_string(),
374 content: content.iter().find_map(|c| match c {
375 message::UserContent::Text(text) => Some(text.text.clone()),
376 _ => None,
377 }),
378 tool_calls: vec![],
379 }),
380 message::Message::Assistant { content, .. } => {
381 let mut text_content: Option<String> = None;
382 let mut tool_calls = vec![];
383
384 for c in content.iter() {
385 match c {
386 message::AssistantContent::Text(text) => {
387 text_content = Some(
388 text_content
389 .map(|mut existing| {
390 existing.push('\n');
391 existing.push_str(&text.text);
392 existing
393 })
394 .unwrap_or_else(|| text.text.clone()),
395 );
396 }
397 message::AssistantContent::ToolCall(tool_call) => {
398 tool_calls.push(tool_call.clone().into());
399 }
400 message::AssistantContent::Reasoning(_) => {
401 return Err(MessageError::ConversionError(
402 "Galadriel currently doesn't support reasoning.".into(),
403 ));
404 }
405 message::AssistantContent::Image(_) => {
406 return Err(MessageError::ConversionError(
407 "Galadriel currently doesn't support images.".into(),
408 ));
409 }
410 }
411 }
412
413 Ok(Self {
414 role: "assistant".to_string(),
415 content: text_content,
416 tool_calls,
417 })
418 }
419 }
420 }
421}
422
423#[derive(Clone, Debug, Deserialize, Serialize)]
424pub struct ToolDefinition {
425 pub r#type: String,
426 pub function: completion::ToolDefinition,
427}
428
429impl From<completion::ToolDefinition> for ToolDefinition {
430 fn from(tool: completion::ToolDefinition) -> Self {
431 Self {
432 r#type: "function".into(),
433 function: tool,
434 }
435 }
436}
437
438#[derive(Debug, Deserialize)]
439pub struct Function {
440 pub name: String,
441 pub arguments: String,
442}
443
444#[derive(Debug, Serialize, Deserialize)]
445pub(super) struct GaladrielCompletionRequest {
446 model: String,
447 pub messages: Vec<Message>,
448 #[serde(skip_serializing_if = "Option::is_none")]
449 temperature: Option<f64>,
450 #[serde(skip_serializing_if = "Vec::is_empty")]
451 tools: Vec<ToolDefinition>,
452 #[serde(skip_serializing_if = "Option::is_none")]
453 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
454 #[serde(flatten, skip_serializing_if = "Option::is_none")]
455 pub additional_params: Option<serde_json::Value>,
456}
457
458impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest {
459 type Error = CompletionError;
460
461 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
462 if req.output_schema.is_some() {
463 tracing::warn!("Structured outputs currently not supported for Galadriel");
464 }
465 let model = req.model.clone().unwrap_or_else(|| model.to_string());
466 let mut partial_history = vec![];
468 if let Some(docs) = req.normalized_documents() {
469 partial_history.push(docs);
470 }
471 partial_history.extend(req.chat_history);
472
473 let mut full_history: Vec<Message> = match &req.preamble {
475 Some(preamble) => vec![Message::system(preamble)],
476 None => vec![],
477 };
478
479 full_history.extend(
481 partial_history
482 .into_iter()
483 .map(message::Message::try_into)
484 .collect::<Result<Vec<Message>, _>>()?,
485 );
486
487 let tool_choice = req
488 .tool_choice
489 .clone()
490 .map(crate::providers::openai::completion::ToolChoice::try_from)
491 .transpose()?;
492
493 Ok(Self {
494 model: model.to_string(),
495 messages: full_history,
496 temperature: req.temperature,
497 tools: req
498 .tools
499 .clone()
500 .into_iter()
501 .map(ToolDefinition::from)
502 .collect::<Vec<_>>(),
503 tool_choice,
504 additional_params: req.additional_params,
505 })
506 }
507}
508
509#[derive(Clone)]
510pub struct CompletionModel<T = reqwest::Client> {
511 client: Client<T>,
512 pub model: String,
514}
515
516impl<T> CompletionModel<T>
517where
518 T: HttpClientExt,
519{
520 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
521 Self {
522 client,
523 model: model.into(),
524 }
525 }
526
527 pub fn with_model(client: Client<T>, model: &str) -> Self {
528 Self {
529 client,
530 model: model.into(),
531 }
532 }
533}
534
535impl<T> completion::CompletionModel for CompletionModel<T>
536where
537 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
538{
539 type Response = CompletionResponse;
540 type StreamingResponse = openai::StreamingCompletionResponse;
541
542 type Client = Client<T>;
543
544 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
545 Self::new(client.clone(), model.into())
546 }
547
548 async fn completion(
549 &self,
550 completion_request: CompletionRequest,
551 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
552 let span = if tracing::Span::current().is_disabled() {
553 info_span!(
554 target: "rig::completions",
555 "chat",
556 gen_ai.operation.name = "chat",
557 gen_ai.provider.name = "galadriel",
558 gen_ai.request.model = self.model,
559 gen_ai.system_instructions = tracing::field::Empty,
560 gen_ai.response.id = tracing::field::Empty,
561 gen_ai.response.model = tracing::field::Empty,
562 gen_ai.usage.output_tokens = tracing::field::Empty,
563 gen_ai.usage.input_tokens = tracing::field::Empty,
564 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
565 )
566 } else {
567 tracing::Span::current()
568 };
569
570 span.record("gen_ai.system_instructions", &completion_request.preamble);
571
572 let request =
573 GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
574
575 if enabled!(tracing::Level::TRACE) {
576 tracing::trace!(target: "rig::completions",
577 "Galadriel completion request: {}",
578 serde_json::to_string_pretty(&request)?
579 );
580 }
581
582 let body = serde_json::to_vec(&request)?;
583
584 let req = self
585 .client
586 .post("/chat/completions")?
587 .body(body)
588 .map_err(http_client::Error::from)?;
589
590 async move {
591 let response = self.client.send(req).await?;
592
593 if response.status().is_success() {
594 let t = http_client::text(response).await?;
595
596 if enabled!(tracing::Level::TRACE) {
597 tracing::trace!(target: "rig::completions",
598 "Galadriel completion response: {}",
599 serde_json::to_string_pretty(&t)?
600 );
601 }
602
603 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
604 ApiResponse::Ok(response) => {
605 let span = tracing::Span::current();
606 span.record("gen_ai.response.id", response.id.clone());
607 span.record("gen_ai.response.model", response.model.clone());
608 if let Some(ref usage) = response.usage {
609 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
610 span.record(
611 "gen_ai.usage.output_tokens",
612 usage.total_tokens - usage.prompt_tokens,
613 );
614 }
615 response.try_into()
616 }
617 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
618 }
619 } else {
620 let text = http_client::text(response).await?;
621
622 Err(CompletionError::ProviderError(text))
623 }
624 }
625 .instrument(span)
626 .await
627 }
628
629 async fn stream(
630 &self,
631 completion_request: CompletionRequest,
632 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
633 let preamble = completion_request.preamble.clone();
634 let mut request =
635 GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
636
637 let params = json_utils::merge(
638 request.additional_params.unwrap_or(serde_json::json!({})),
639 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
640 );
641
642 request.additional_params = Some(params);
643
644 let body = serde_json::to_vec(&request)?;
645
646 let req = self
647 .client
648 .post("/chat/completions")?
649 .body(body)
650 .map_err(http_client::Error::from)?;
651
652 let span = if tracing::Span::current().is_disabled() {
653 info_span!(
654 target: "rig::completions",
655 "chat_streaming",
656 gen_ai.operation.name = "chat_streaming",
657 gen_ai.provider.name = "galadriel",
658 gen_ai.request.model = self.model,
659 gen_ai.system_instructions = preamble,
660 gen_ai.response.id = tracing::field::Empty,
661 gen_ai.response.model = tracing::field::Empty,
662 gen_ai.usage.output_tokens = tracing::field::Empty,
663 gen_ai.usage.input_tokens = tracing::field::Empty,
664 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
665 gen_ai.input.messages = serde_json::to_string(&request.messages)?,
666 gen_ai.output.messages = tracing::field::Empty,
667 )
668 } else {
669 tracing::Span::current()
670 };
671
672 send_compatible_streaming_request(self.client.clone(), req)
673 .instrument(span)
674 .await
675 }
676}
#[cfg(test)]
mod tests {
    use crate::providers::galadriel::Client;

    /// A client can be created both directly from an API key and through the
    /// builder.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}