1use super::openai;
14use crate::client::{
15 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
16 ProviderClient,
17};
18use crate::http_client::{self, HttpClientExt};
19use crate::message::MessageError;
20use crate::providers::openai::send_compatible_streaming_request;
21use crate::streaming::StreamingCompletionResponse;
22use crate::{
23 OneOrMany,
24 completion::{self, CompletionError, CompletionRequest},
25 json_utils, message,
26};
27use serde::{Deserialize, Serialize};
28use tracing::{Instrument, enabled, info_span};
29
/// Default base URL for Galadriel requests; exposed as `ProviderBuilder::BASE_URL`
/// on [`GaladrielBuilder`].
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
34
/// Provider extension state stored on a built Galadriel client.
///
/// Holds the optional fine-tune API key. `Debug` is implemented by hand so the
/// key itself is never written to logs or debug output.
#[derive(Default, Clone)]
pub struct GaladrielExt {
    fine_tune_api_key: Option<String>,
}

/// Builder-side extension state for Galadriel; turned into [`GaladrielExt`]
/// when the client is built.
///
/// Like [`GaladrielExt`], its `Debug` output redacts the fine-tune API key.
#[derive(Default, Clone)]
pub struct GaladrielBuilder {
    fine_tune_api_key: Option<String>,
}

/// Replaces a secret with a fixed placeholder, preserving only whether it is set.
fn redact_secret(secret: Option<&str>) -> Option<&'static str> {
    secret.map(|_| "<REDACTED>")
}

impl std::fmt::Debug for GaladrielExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GaladrielExt")
            .field(
                "fine_tune_api_key",
                &redact_secret(self.fine_tune_api_key.as_deref()),
            )
            .finish()
    }
}

impl std::fmt::Debug for GaladrielBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GaladrielBuilder")
            .field(
                "fine_tune_api_key",
                &redact_secret(self.fine_tune_api_key.as_deref()),
            )
            .finish()
    }
}
44
45type GaladrielApiKey = BearerAuth;
46
47impl Provider for GaladrielExt {
48 type Builder = GaladrielBuilder;
49
50 const VERIFY_PATH: &'static str = "";
52}
53
54impl<H> Capabilities<H> for GaladrielExt {
55 type Completion = Capable<CompletionModel<H>>;
56 type Embeddings = Nothing;
57 type Transcription = Nothing;
58 type ModelListing = Nothing;
59 #[cfg(feature = "image")]
60 type ImageGeneration = Nothing;
61 #[cfg(feature = "audio")]
62 type AudioGeneration = Nothing;
63}
64
65impl DebugExt for GaladrielExt {
66 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
67 std::iter::once((
68 "fine_tune_api_key",
69 (&self.fine_tune_api_key as &dyn std::fmt::Debug),
70 ))
71 }
72}
73
74impl ProviderBuilder for GaladrielBuilder {
75 type Extension<H>
76 = GaladrielExt
77 where
78 H: HttpClientExt;
79 type ApiKey = GaladrielApiKey;
80
81 const BASE_URL: &'static str = GALADRIEL_API_BASE_URL;
82
83 fn build<H>(
84 builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
85 ) -> http_client::Result<Self::Extension<H>>
86 where
87 H: HttpClientExt,
88 {
89 let GaladrielBuilder { fine_tune_api_key } = builder.ext().clone();
90
91 Ok(GaladrielExt { fine_tune_api_key })
92 }
93}
94
95pub type Client<H = reqwest::Client> = client::Client<GaladrielExt, H>;
96pub type ClientBuilder<H = reqwest::Client> =
97 client::ClientBuilder<GaladrielBuilder, GaladrielApiKey, H>;
98
99impl<T> ClientBuilder<T> {
100 pub fn fine_tune_api_key<S>(mut self, fine_tune_api_key: S) -> Self
101 where
102 S: AsRef<str>,
103 {
104 *self.ext_mut() = GaladrielBuilder {
105 fine_tune_api_key: Some(fine_tune_api_key.as_ref().into()),
106 };
107
108 self
109 }
110}
111
112impl ProviderClient for Client {
113 type Input = (String, Option<String>);
114
115 fn from_env() -> Self {
119 let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
120 let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
121
122 let mut builder = Self::builder().api_key(api_key);
123
124 if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
125 builder = builder.fine_tune_api_key(fine_tune_api_key);
126 }
127
128 builder.build().unwrap()
129 }
130
131 fn from_val((api_key, fine_tune_api_key): Self::Input) -> Self {
132 let mut builder = Self::builder().api_key(api_key);
133
134 if let Some(fine_tune_key) = fine_tune_api_key {
135 builder = builder.fine_tune_api_key(fine_tune_key)
136 }
137
138 builder.build().unwrap()
139 }
140}
141
142#[derive(Debug, Deserialize)]
143struct ApiErrorResponse {
144 message: String,
145}
146
147#[derive(Debug, Deserialize)]
148#[serde(untagged)]
149enum ApiResponse<T> {
150 Ok(T),
151 Err(ApiErrorResponse),
152}
153
154#[derive(Clone, Debug, Deserialize, Serialize)]
155pub struct Usage {
156 pub prompt_tokens: usize,
157 pub total_tokens: usize,
158}
159
160impl std::fmt::Display for Usage {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 write!(
163 f,
164 "Prompt tokens: {} Total tokens: {}",
165 self.prompt_tokens, self.total_tokens
166 )
167 }
168}
169
// Model names intended for `CompletionModel::new`; the chosen name is sent
// verbatim as the request's `model` field. Whether Galadriel serves a given
// model is not checked here.

/// Model name `o1-preview`.
pub const O1_PREVIEW: &str = "o1-preview";
/// Model name `o1-preview-2024-09-12`.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// Model name `o1-mini`.
pub const O1_MINI: &str = "o1-mini";
/// Model name `o1-mini-2024-09-12`.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// Model name `gpt-4o`.
pub const GPT_4O: &str = "gpt-4o";
/// Model name `gpt-4o-2024-05-13`.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// Model name `gpt-4-turbo`.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// Model name `gpt-4-turbo-2024-04-09`.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// Model name `gpt-4-turbo-preview`.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// Model name `gpt-4-0125-preview`.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// Model name `gpt-4-1106-preview`.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// Model name `gpt-4-vision-preview`.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// Model name `gpt-4-1106-vision-preview`.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// Model name `gpt-4`.
pub const GPT_4: &str = "gpt-4";
/// Model name `gpt-4-0613`.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// Model name `gpt-4-32k`.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// Model name `gpt-4-32k-0613`.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// Model name `gpt-3.5-turbo`.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// Model name `gpt-3.5-turbo-0125`.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// Model name `gpt-3.5-turbo-1106`.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// Model name `gpt-3.5-turbo-instruct`.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
216
217#[derive(Debug, Deserialize, Serialize)]
218pub struct CompletionResponse {
219 pub id: String,
220 pub object: String,
221 pub created: u64,
222 pub model: String,
223 pub system_fingerprint: Option<String>,
224 pub choices: Vec<Choice>,
225 pub usage: Option<Usage>,
226}
227
228impl From<ApiErrorResponse> for CompletionError {
229 fn from(err: ApiErrorResponse) -> Self {
230 CompletionError::ProviderError(err.message)
231 }
232}
233
234impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
235 type Error = CompletionError;
236
237 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
238 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
239 CompletionError::ResponseError("Response contained no choices".to_owned())
240 })?;
241
242 let mut content = message
243 .content
244 .as_ref()
245 .map(|c| vec![completion::AssistantContent::text(c)])
246 .unwrap_or_default();
247
248 content.extend(message.tool_calls.iter().map(|call| {
249 completion::AssistantContent::tool_call(
250 &call.function.name,
251 &call.function.name,
252 call.function.arguments.clone(),
253 )
254 }));
255
256 let choice = OneOrMany::many(content).map_err(|_| {
257 CompletionError::ResponseError(
258 "Response contained no message or tool call (empty)".to_owned(),
259 )
260 })?;
261 let usage = response
262 .usage
263 .as_ref()
264 .map(|usage| completion::Usage {
265 input_tokens: usage.prompt_tokens as u64,
266 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
267 total_tokens: usage.total_tokens as u64,
268 cached_input_tokens: 0,
269 cache_creation_input_tokens: 0,
270 })
271 .unwrap_or_default();
272
273 Ok(completion::CompletionResponse {
274 choice,
275 usage,
276 raw_response: response,
277 message_id: None,
278 })
279 }
280}
281
282#[derive(Debug, Deserialize, Serialize)]
283pub struct Choice {
284 pub index: usize,
285 pub message: Message,
286 pub logprobs: Option<serde_json::Value>,
287 pub finish_reason: String,
288}
289
290#[derive(Debug, Serialize, Deserialize)]
291pub struct Message {
292 pub role: String,
293 pub content: Option<String>,
294 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
295 pub tool_calls: Vec<openai::ToolCall>,
296}
297
298impl Message {
299 fn system(preamble: &str) -> Self {
300 Self {
301 role: "system".to_string(),
302 content: Some(preamble.to_string()),
303 tool_calls: Vec::new(),
304 }
305 }
306}
307
308impl TryFrom<Message> for message::Message {
309 type Error = message::MessageError;
310
311 fn try_from(message: Message) -> Result<Self, Self::Error> {
312 let tool_calls: Vec<message::ToolCall> = message
313 .tool_calls
314 .into_iter()
315 .map(|tool_call| tool_call.into())
316 .collect();
317
318 match message.role.as_str() {
319 "user" => Ok(Self::User {
320 content: OneOrMany::one(
321 message
322 .content
323 .map(|content| message::UserContent::text(&content))
324 .ok_or_else(|| {
325 message::MessageError::ConversionError("Empty user message".to_string())
326 })?,
327 ),
328 }),
329 "assistant" => Ok(Self::Assistant {
330 id: None,
331 content: OneOrMany::many(
332 tool_calls
333 .into_iter()
334 .map(message::AssistantContent::ToolCall)
335 .chain(
336 message
337 .content
338 .map(|content| message::AssistantContent::text(&content))
339 .into_iter(),
340 ),
341 )
342 .map_err(|_| {
343 message::MessageError::ConversionError("Empty assistant message".to_string())
344 })?,
345 }),
346 _ => Err(message::MessageError::ConversionError(format!(
347 "Unknown role: {}",
348 message.role
349 ))),
350 }
351 }
352}
353
354impl TryFrom<message::Message> for Message {
355 type Error = message::MessageError;
356
357 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
358 match message {
359 message::Message::System { content } => Ok(Self {
360 role: "system".to_string(),
361 content: Some(content),
362 tool_calls: vec![],
363 }),
364 message::Message::User { content } => Ok(Self {
365 role: "user".to_string(),
366 content: content.iter().find_map(|c| match c {
367 message::UserContent::Text(text) => Some(text.text.clone()),
368 _ => None,
369 }),
370 tool_calls: vec![],
371 }),
372 message::Message::Assistant { content, .. } => {
373 let mut text_content: Option<String> = None;
374 let mut tool_calls = vec![];
375
376 for c in content.iter() {
377 match c {
378 message::AssistantContent::Text(text) => {
379 text_content = Some(
380 text_content
381 .map(|mut existing| {
382 existing.push('\n');
383 existing.push_str(&text.text);
384 existing
385 })
386 .unwrap_or_else(|| text.text.clone()),
387 );
388 }
389 message::AssistantContent::ToolCall(tool_call) => {
390 tool_calls.push(tool_call.clone().into());
391 }
392 message::AssistantContent::Reasoning(_) => {
393 return Err(MessageError::ConversionError(
394 "Galadriel currently doesn't support reasoning.".into(),
395 ));
396 }
397 message::AssistantContent::Image(_) => {
398 return Err(MessageError::ConversionError(
399 "Galadriel currently doesn't support images.".into(),
400 ));
401 }
402 }
403 }
404
405 Ok(Self {
406 role: "assistant".to_string(),
407 content: text_content,
408 tool_calls,
409 })
410 }
411 }
412 }
413}
414
415#[derive(Clone, Debug, Deserialize, Serialize)]
416pub struct ToolDefinition {
417 pub r#type: String,
418 pub function: completion::ToolDefinition,
419}
420
421impl From<completion::ToolDefinition> for ToolDefinition {
422 fn from(tool: completion::ToolDefinition) -> Self {
423 Self {
424 r#type: "function".into(),
425 function: tool,
426 }
427 }
428}
429
430#[derive(Debug, Deserialize)]
431pub struct Function {
432 pub name: String,
433 pub arguments: String,
434}
435
436#[derive(Debug, Serialize, Deserialize)]
437pub(super) struct GaladrielCompletionRequest {
438 model: String,
439 pub messages: Vec<Message>,
440 #[serde(skip_serializing_if = "Option::is_none")]
441 temperature: Option<f64>,
442 #[serde(skip_serializing_if = "Vec::is_empty")]
443 tools: Vec<ToolDefinition>,
444 #[serde(skip_serializing_if = "Option::is_none")]
445 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
446 #[serde(flatten, skip_serializing_if = "Option::is_none")]
447 pub additional_params: Option<serde_json::Value>,
448}
449
450impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest {
451 type Error = CompletionError;
452
453 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
454 if req.output_schema.is_some() {
455 tracing::warn!("Structured outputs currently not supported for Galadriel");
456 }
457 let model = req.model.clone().unwrap_or_else(|| model.to_string());
458 let mut partial_history = vec![];
460 if let Some(docs) = req.normalized_documents() {
461 partial_history.push(docs);
462 }
463 partial_history.extend(req.chat_history);
464
465 let mut full_history: Vec<Message> = match &req.preamble {
467 Some(preamble) => vec![Message::system(preamble)],
468 None => vec![],
469 };
470
471 full_history.extend(
473 partial_history
474 .into_iter()
475 .map(message::Message::try_into)
476 .collect::<Result<Vec<Message>, _>>()?,
477 );
478
479 let tool_choice = req
480 .tool_choice
481 .clone()
482 .map(crate::providers::openai::completion::ToolChoice::try_from)
483 .transpose()?;
484
485 Ok(Self {
486 model: model.to_string(),
487 messages: full_history,
488 temperature: req.temperature,
489 tools: req
490 .tools
491 .clone()
492 .into_iter()
493 .map(ToolDefinition::from)
494 .collect::<Vec<_>>(),
495 tool_choice,
496 additional_params: req.additional_params,
497 })
498 }
499}
500
501#[derive(Clone)]
502pub struct CompletionModel<T = reqwest::Client> {
503 client: Client<T>,
504 pub model: String,
506}
507
508impl<T> CompletionModel<T>
509where
510 T: HttpClientExt,
511{
512 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
513 Self {
514 client,
515 model: model.into(),
516 }
517 }
518
519 pub fn with_model(client: Client<T>, model: &str) -> Self {
520 Self {
521 client,
522 model: model.into(),
523 }
524 }
525}
526
527impl<T> completion::CompletionModel for CompletionModel<T>
528where
529 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
530{
531 type Response = CompletionResponse;
532 type StreamingResponse = openai::StreamingCompletionResponse;
533
534 type Client = Client<T>;
535
536 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
537 Self::new(client.clone(), model.into())
538 }
539
540 async fn completion(
541 &self,
542 completion_request: CompletionRequest,
543 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
544 let span = if tracing::Span::current().is_disabled() {
545 info_span!(
546 target: "rig::completions",
547 "chat",
548 gen_ai.operation.name = "chat",
549 gen_ai.provider.name = "galadriel",
550 gen_ai.request.model = self.model,
551 gen_ai.system_instructions = tracing::field::Empty,
552 gen_ai.response.id = tracing::field::Empty,
553 gen_ai.response.model = tracing::field::Empty,
554 gen_ai.usage.output_tokens = tracing::field::Empty,
555 gen_ai.usage.input_tokens = tracing::field::Empty,
556 gen_ai.usage.cached_tokens = tracing::field::Empty,
557 )
558 } else {
559 tracing::Span::current()
560 };
561
562 span.record("gen_ai.system_instructions", &completion_request.preamble);
563
564 let request =
565 GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
566
567 if enabled!(tracing::Level::TRACE) {
568 tracing::trace!(target: "rig::completions",
569 "Galadriel completion request: {}",
570 serde_json::to_string_pretty(&request)?
571 );
572 }
573
574 let body = serde_json::to_vec(&request)?;
575
576 let req = self
577 .client
578 .post("/chat/completions")?
579 .body(body)
580 .map_err(http_client::Error::from)?;
581
582 async move {
583 let response = self.client.send(req).await?;
584
585 if response.status().is_success() {
586 let t = http_client::text(response).await?;
587
588 if enabled!(tracing::Level::TRACE) {
589 tracing::trace!(target: "rig::completions",
590 "Galadriel completion response: {}",
591 serde_json::to_string_pretty(&t)?
592 );
593 }
594
595 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
596 ApiResponse::Ok(response) => {
597 let span = tracing::Span::current();
598 span.record("gen_ai.response.id", response.id.clone());
599 span.record("gen_ai.response.model_name", response.model.clone());
600 if let Some(ref usage) = response.usage {
601 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
602 span.record(
603 "gen_ai.usage.output_tokens",
604 usage.total_tokens - usage.prompt_tokens,
605 );
606 }
607 response.try_into()
608 }
609 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
610 }
611 } else {
612 let text = http_client::text(response).await?;
613
614 Err(CompletionError::ProviderError(text))
615 }
616 }
617 .instrument(span)
618 .await
619 }
620
621 async fn stream(
622 &self,
623 completion_request: CompletionRequest,
624 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
625 let preamble = completion_request.preamble.clone();
626 let mut request =
627 GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
628
629 let params = json_utils::merge(
630 request.additional_params.unwrap_or(serde_json::json!({})),
631 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
632 );
633
634 request.additional_params = Some(params);
635
636 let body = serde_json::to_vec(&request)?;
637
638 let req = self
639 .client
640 .post("/chat/completions")?
641 .body(body)
642 .map_err(http_client::Error::from)?;
643
644 let span = if tracing::Span::current().is_disabled() {
645 info_span!(
646 target: "rig::completions",
647 "chat_streaming",
648 gen_ai.operation.name = "chat_streaming",
649 gen_ai.provider.name = "galadriel",
650 gen_ai.request.model = self.model,
651 gen_ai.system_instructions = preamble,
652 gen_ai.response.id = tracing::field::Empty,
653 gen_ai.response.model = tracing::field::Empty,
654 gen_ai.usage.output_tokens = tracing::field::Empty,
655 gen_ai.usage.input_tokens = tracing::field::Empty,
656 gen_ai.usage.cached_tokens = tracing::field::Empty,
657 gen_ai.input.messages = serde_json::to_string(&request.messages)?,
658 gen_ai.output.messages = tracing::field::Empty,
659 )
660 } else {
661 tracing::Span::current()
662 };
663
664 send_compatible_streaming_request(self.client.clone(), req)
665 .instrument(span)
666 .await
667 }
668}
#[cfg(test)]
mod tests {
    use super::Client;

    /// Both construction paths, `Client::new` and the builder, must succeed
    /// with a placeholder API key.
    #[test]
    fn test_client_initialization() {
        let _client = Client::new("dummy-key").expect("Client::new() failed");

        let _client_from_builder = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}