1use super::openai;
14use crate::client::{
15 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
16 ProviderClient,
17};
18use crate::http_client::{self, HttpClientExt};
19use crate::message::MessageError;
20use crate::providers::openai::send_compatible_streaming_request;
21use crate::streaming::StreamingCompletionResponse;
22use crate::{
23 OneOrMany,
24 completion::{self, CompletionError, CompletionRequest},
25 json_utils, message,
26};
27use serde::{Deserialize, Serialize};
28use tracing::{Instrument, enabled, info_span};
29
/// Base URL of Galadriel's verified, OpenAI-compatible API; request paths such as
/// `/chat/completions` are appended to it.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
34
/// Provider extension state carried by a built Galadriel client.
///
/// Holds the optional fine-tune API key. `Debug` is implemented by hand so that the
/// key's value never appears in debug output; only whether it is set is shown.
#[derive(Default, Clone)]
pub struct GaladrielExt {
    fine_tune_api_key: Option<String>,
}

impl std::fmt::Debug for GaladrielExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GaladrielExt")
            .field(
                "fine_tune_api_key",
                &redacted_fine_tune_key(self.fine_tune_api_key.as_deref()),
            )
            .finish()
    }
}

/// Builder-side extension state for a Galadriel client.
///
/// Its `Debug` output redacts the fine-tune API key, like [`GaladrielExt`].
#[derive(Default, Clone)]
pub struct GaladrielBuilder {
    fine_tune_api_key: Option<String>,
}

impl std::fmt::Debug for GaladrielBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GaladrielBuilder")
            .field(
                "fine_tune_api_key",
                &redacted_fine_tune_key(self.fine_tune_api_key.as_deref()),
            )
            .finish()
    }
}

/// Replaces a secret with a fixed placeholder, keeping only whether it was set.
fn redacted_fine_tune_key(key: Option<&str>) -> Option<&'static str> {
    key.map(|_| "<redacted>")
}
44
45type GaladrielApiKey = BearerAuth;
46
47impl Provider for GaladrielExt {
48 type Builder = GaladrielBuilder;
49
50 const VERIFY_PATH: &'static str = "";
52
53 fn build<H>(
54 builder: &crate::client::ClientBuilder<
55 Self::Builder,
56 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
57 H,
58 >,
59 ) -> http_client::Result<Self> {
60 let GaladrielBuilder { fine_tune_api_key } = builder.ext().clone();
61
62 Ok(Self { fine_tune_api_key })
63 }
64}
65
66impl<H> Capabilities<H> for GaladrielExt {
67 type Completion = Capable<CompletionModel<H>>;
68 type Embeddings = Nothing;
69 type Transcription = Nothing;
70 type ModelListing = Nothing;
71 #[cfg(feature = "image")]
72 type ImageGeneration = Nothing;
73 #[cfg(feature = "audio")]
74 type AudioGeneration = Nothing;
75}
76
77impl DebugExt for GaladrielExt {
78 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
79 std::iter::once((
80 "fine_tune_api_key",
81 (&self.fine_tune_api_key as &dyn std::fmt::Debug),
82 ))
83 }
84}
85
86impl ProviderBuilder for GaladrielBuilder {
87 type Output = GaladrielExt;
88 type ApiKey = GaladrielApiKey;
89
90 const BASE_URL: &'static str = GALADRIEL_API_BASE_URL;
91}
92
93pub type Client<H = reqwest::Client> = client::Client<GaladrielExt, H>;
94pub type ClientBuilder<H = reqwest::Client> =
95 client::ClientBuilder<GaladrielBuilder, GaladrielApiKey, H>;
96
97impl<T> ClientBuilder<T> {
98 pub fn fine_tune_api_key<S>(mut self, fine_tune_api_key: S) -> Self
99 where
100 S: AsRef<str>,
101 {
102 *self.ext_mut() = GaladrielBuilder {
103 fine_tune_api_key: Some(fine_tune_api_key.as_ref().into()),
104 };
105
106 self
107 }
108}
109
110impl ProviderClient for Client {
111 type Input = (String, Option<String>);
112
113 fn from_env() -> Self {
117 let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
118 let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
119
120 let mut builder = Self::builder().api_key(api_key);
121
122 if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
123 builder = builder.fine_tune_api_key(fine_tune_api_key);
124 }
125
126 builder.build().unwrap()
127 }
128
129 fn from_val((api_key, fine_tune_api_key): Self::Input) -> Self {
130 let mut builder = Self::builder().api_key(api_key);
131
132 if let Some(fine_tune_key) = fine_tune_api_key {
133 builder = builder.fine_tune_api_key(fine_tune_key)
134 }
135
136 builder.build().unwrap()
137 }
138}
139
140#[derive(Debug, Deserialize)]
141struct ApiErrorResponse {
142 message: String,
143}
144
145#[derive(Debug, Deserialize)]
146#[serde(untagged)]
147enum ApiResponse<T> {
148 Ok(T),
149 Err(ApiErrorResponse),
150}
151
152#[derive(Clone, Debug, Deserialize, Serialize)]
153pub struct Usage {
154 pub prompt_tokens: usize,
155 pub total_tokens: usize,
156}
157
158impl std::fmt::Display for Usage {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 write!(
161 f,
162 "Prompt tokens: {} Total tokens: {}",
163 self.prompt_tokens, self.total_tokens
164 )
165 }
166}
167
// ---- o1 family ----

/// `o1-preview` model.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview` snapshot from 2024-09-12.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` model.
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini` snapshot from 2024-09-12.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";

// ---- GPT-4o / GPT-4 family ----

/// `gpt-4o` model.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o` snapshot from 2024-05-13.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` model.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo` snapshot from 2024-04-09.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` model.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` model.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` model.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` model.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` model.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` model.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4` snapshot `0613`.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` model.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k` snapshot `0613`.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";

// ---- GPT-3.5 family ----

/// `gpt-3.5-turbo` model.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo` snapshot `0125`.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo` snapshot `1106`.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` model.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
214
215#[derive(Debug, Deserialize, Serialize)]
216pub struct CompletionResponse {
217 pub id: String,
218 pub object: String,
219 pub created: u64,
220 pub model: String,
221 pub system_fingerprint: Option<String>,
222 pub choices: Vec<Choice>,
223 pub usage: Option<Usage>,
224}
225
226impl From<ApiErrorResponse> for CompletionError {
227 fn from(err: ApiErrorResponse) -> Self {
228 CompletionError::ProviderError(err.message)
229 }
230}
231
232impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
233 type Error = CompletionError;
234
235 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
236 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
237 CompletionError::ResponseError("Response contained no choices".to_owned())
238 })?;
239
240 let mut content = message
241 .content
242 .as_ref()
243 .map(|c| vec![completion::AssistantContent::text(c)])
244 .unwrap_or_default();
245
246 content.extend(message.tool_calls.iter().map(|call| {
247 completion::AssistantContent::tool_call(
248 &call.function.name,
249 &call.function.name,
250 call.function.arguments.clone(),
251 )
252 }));
253
254 let choice = OneOrMany::many(content).map_err(|_| {
255 CompletionError::ResponseError(
256 "Response contained no message or tool call (empty)".to_owned(),
257 )
258 })?;
259 let usage = response
260 .usage
261 .as_ref()
262 .map(|usage| completion::Usage {
263 input_tokens: usage.prompt_tokens as u64,
264 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
265 total_tokens: usage.total_tokens as u64,
266 cached_input_tokens: 0,
267 })
268 .unwrap_or_default();
269
270 Ok(completion::CompletionResponse {
271 choice,
272 usage,
273 raw_response: response,
274 message_id: None,
275 })
276 }
277}
278
279#[derive(Debug, Deserialize, Serialize)]
280pub struct Choice {
281 pub index: usize,
282 pub message: Message,
283 pub logprobs: Option<serde_json::Value>,
284 pub finish_reason: String,
285}
286
287#[derive(Debug, Serialize, Deserialize)]
288pub struct Message {
289 pub role: String,
290 pub content: Option<String>,
291 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
292 pub tool_calls: Vec<openai::ToolCall>,
293}
294
295impl Message {
296 fn system(preamble: &str) -> Self {
297 Self {
298 role: "system".to_string(),
299 content: Some(preamble.to_string()),
300 tool_calls: Vec::new(),
301 }
302 }
303}
304
305impl TryFrom<Message> for message::Message {
306 type Error = message::MessageError;
307
308 fn try_from(message: Message) -> Result<Self, Self::Error> {
309 let tool_calls: Vec<message::ToolCall> = message
310 .tool_calls
311 .into_iter()
312 .map(|tool_call| tool_call.into())
313 .collect();
314
315 match message.role.as_str() {
316 "user" => Ok(Self::User {
317 content: OneOrMany::one(
318 message
319 .content
320 .map(|content| message::UserContent::text(&content))
321 .ok_or_else(|| {
322 message::MessageError::ConversionError("Empty user message".to_string())
323 })?,
324 ),
325 }),
326 "assistant" => Ok(Self::Assistant {
327 id: None,
328 content: OneOrMany::many(
329 tool_calls
330 .into_iter()
331 .map(message::AssistantContent::ToolCall)
332 .chain(
333 message
334 .content
335 .map(|content| message::AssistantContent::text(&content))
336 .into_iter(),
337 ),
338 )
339 .map_err(|_| {
340 message::MessageError::ConversionError("Empty assistant message".to_string())
341 })?,
342 }),
343 _ => Err(message::MessageError::ConversionError(format!(
344 "Unknown role: {}",
345 message.role
346 ))),
347 }
348 }
349}
350
351impl TryFrom<message::Message> for Message {
352 type Error = message::MessageError;
353
354 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
355 match message {
356 message::Message::User { content } => Ok(Self {
357 role: "user".to_string(),
358 content: content.iter().find_map(|c| match c {
359 message::UserContent::Text(text) => Some(text.text.clone()),
360 _ => None,
361 }),
362 tool_calls: vec![],
363 }),
364 message::Message::Assistant { content, .. } => {
365 let mut text_content: Option<String> = None;
366 let mut tool_calls = vec![];
367
368 for c in content.iter() {
369 match c {
370 message::AssistantContent::Text(text) => {
371 text_content = Some(
372 text_content
373 .map(|mut existing| {
374 existing.push('\n');
375 existing.push_str(&text.text);
376 existing
377 })
378 .unwrap_or_else(|| text.text.clone()),
379 );
380 }
381 message::AssistantContent::ToolCall(tool_call) => {
382 tool_calls.push(tool_call.clone().into());
383 }
384 message::AssistantContent::Reasoning(_) => {
385 return Err(MessageError::ConversionError(
386 "Galadriel currently doesn't support reasoning.".into(),
387 ));
388 }
389 message::AssistantContent::Image(_) => {
390 return Err(MessageError::ConversionError(
391 "Galadriel currently doesn't support images.".into(),
392 ));
393 }
394 }
395 }
396
397 Ok(Self {
398 role: "assistant".to_string(),
399 content: text_content,
400 tool_calls,
401 })
402 }
403 }
404 }
405}
406
407#[derive(Clone, Debug, Deserialize, Serialize)]
408pub struct ToolDefinition {
409 pub r#type: String,
410 pub function: completion::ToolDefinition,
411}
412
413impl From<completion::ToolDefinition> for ToolDefinition {
414 fn from(tool: completion::ToolDefinition) -> Self {
415 Self {
416 r#type: "function".into(),
417 function: tool,
418 }
419 }
420}
421
422#[derive(Debug, Deserialize)]
423pub struct Function {
424 pub name: String,
425 pub arguments: String,
426}
427
428#[derive(Debug, Serialize, Deserialize)]
429pub(super) struct GaladrielCompletionRequest {
430 model: String,
431 pub messages: Vec<Message>,
432 #[serde(skip_serializing_if = "Option::is_none")]
433 temperature: Option<f64>,
434 #[serde(skip_serializing_if = "Vec::is_empty")]
435 tools: Vec<ToolDefinition>,
436 #[serde(skip_serializing_if = "Option::is_none")]
437 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
438 #[serde(flatten, skip_serializing_if = "Option::is_none")]
439 pub additional_params: Option<serde_json::Value>,
440}
441
442impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest {
443 type Error = CompletionError;
444
445 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
446 if req.output_schema.is_some() {
447 tracing::warn!("Structured outputs currently not supported for Galadriel");
448 }
449 let model = req.model.clone().unwrap_or_else(|| model.to_string());
450 let mut partial_history = vec![];
452 if let Some(docs) = req.normalized_documents() {
453 partial_history.push(docs);
454 }
455 partial_history.extend(req.chat_history);
456
457 let mut full_history: Vec<Message> = match &req.preamble {
459 Some(preamble) => vec![Message::system(preamble)],
460 None => vec![],
461 };
462
463 full_history.extend(
465 partial_history
466 .into_iter()
467 .map(message::Message::try_into)
468 .collect::<Result<Vec<Message>, _>>()?,
469 );
470
471 let tool_choice = req
472 .tool_choice
473 .clone()
474 .map(crate::providers::openai::completion::ToolChoice::try_from)
475 .transpose()?;
476
477 Ok(Self {
478 model: model.to_string(),
479 messages: full_history,
480 temperature: req.temperature,
481 tools: req
482 .tools
483 .clone()
484 .into_iter()
485 .map(ToolDefinition::from)
486 .collect::<Vec<_>>(),
487 tool_choice,
488 additional_params: req.additional_params,
489 })
490 }
491}
492
493#[derive(Clone)]
494pub struct CompletionModel<T = reqwest::Client> {
495 client: Client<T>,
496 pub model: String,
498}
499
500impl<T> CompletionModel<T>
501where
502 T: HttpClientExt,
503{
504 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
505 Self {
506 client,
507 model: model.into(),
508 }
509 }
510
511 pub fn with_model(client: Client<T>, model: &str) -> Self {
512 Self {
513 client,
514 model: model.into(),
515 }
516 }
517}
518
519impl<T> completion::CompletionModel for CompletionModel<T>
520where
521 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
522{
523 type Response = CompletionResponse;
524 type StreamingResponse = openai::StreamingCompletionResponse;
525
526 type Client = Client<T>;
527
528 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
529 Self::new(client.clone(), model.into())
530 }
531
532 async fn completion(
533 &self,
534 completion_request: CompletionRequest,
535 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
536 let span = if tracing::Span::current().is_disabled() {
537 info_span!(
538 target: "rig::completions",
539 "chat",
540 gen_ai.operation.name = "chat",
541 gen_ai.provider.name = "galadriel",
542 gen_ai.request.model = self.model,
543 gen_ai.system_instructions = tracing::field::Empty,
544 gen_ai.response.id = tracing::field::Empty,
545 gen_ai.response.model = tracing::field::Empty,
546 gen_ai.usage.output_tokens = tracing::field::Empty,
547 gen_ai.usage.input_tokens = tracing::field::Empty,
548 )
549 } else {
550 tracing::Span::current()
551 };
552
553 span.record("gen_ai.system_instructions", &completion_request.preamble);
554
555 let request =
556 GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
557
558 if enabled!(tracing::Level::TRACE) {
559 tracing::trace!(target: "rig::completions",
560 "Galadriel completion request: {}",
561 serde_json::to_string_pretty(&request)?
562 );
563 }
564
565 let body = serde_json::to_vec(&request)?;
566
567 let req = self
568 .client
569 .post("/chat/completions")?
570 .body(body)
571 .map_err(http_client::Error::from)?;
572
573 async move {
574 let response = self.client.send(req).await?;
575
576 if response.status().is_success() {
577 let t = http_client::text(response).await?;
578
579 if enabled!(tracing::Level::TRACE) {
580 tracing::trace!(target: "rig::completions",
581 "Galadriel completion response: {}",
582 serde_json::to_string_pretty(&t)?
583 );
584 }
585
586 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
587 ApiResponse::Ok(response) => {
588 let span = tracing::Span::current();
589 span.record("gen_ai.response.id", response.id.clone());
590 span.record("gen_ai.response.model_name", response.model.clone());
591 if let Some(ref usage) = response.usage {
592 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
593 span.record(
594 "gen_ai.usage.output_tokens",
595 usage.total_tokens - usage.prompt_tokens,
596 );
597 }
598 response.try_into()
599 }
600 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
601 }
602 } else {
603 let text = http_client::text(response).await?;
604
605 Err(CompletionError::ProviderError(text))
606 }
607 }
608 .instrument(span)
609 .await
610 }
611
612 async fn stream(
613 &self,
614 completion_request: CompletionRequest,
615 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
616 let preamble = completion_request.preamble.clone();
617 let mut request =
618 GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
619
620 let params = json_utils::merge(
621 request.additional_params.unwrap_or(serde_json::json!({})),
622 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
623 );
624
625 request.additional_params = Some(params);
626
627 let body = serde_json::to_vec(&request)?;
628
629 let req = self
630 .client
631 .post("/chat/completions")?
632 .body(body)
633 .map_err(http_client::Error::from)?;
634
635 let span = if tracing::Span::current().is_disabled() {
636 info_span!(
637 target: "rig::completions",
638 "chat_streaming",
639 gen_ai.operation.name = "chat_streaming",
640 gen_ai.provider.name = "galadriel",
641 gen_ai.request.model = self.model,
642 gen_ai.system_instructions = preamble,
643 gen_ai.response.id = tracing::field::Empty,
644 gen_ai.response.model = tracing::field::Empty,
645 gen_ai.usage.output_tokens = tracing::field::Empty,
646 gen_ai.usage.input_tokens = tracing::field::Empty,
647 gen_ai.input.messages = serde_json::to_string(&request.messages)?,
648 gen_ai.output.messages = tracing::field::Empty,
649 )
650 } else {
651 tracing::Span::current()
652 };
653
654 send_compatible_streaming_request(self.client.clone(), req)
655 .instrument(span)
656 .await
657 }
658}
#[cfg(test)]
mod tests {
    use super::Client;

    /// Both construction paths (direct and via the builder) must succeed with a key.
    #[test]
    fn test_client_initialization() {
        let _direct: Client = Client::new("dummy-key").expect("Client::new() failed");

        let _built: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}