1use super::openai;
14use crate::client::{
15 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
16 ProviderClient,
17};
18use crate::http_client::{self, HttpClientExt};
19use crate::message::MessageError;
20use crate::providers::openai::send_compatible_streaming_request;
21use crate::streaming::StreamingCompletionResponse;
22use crate::{
23 OneOrMany,
24 completion::{self, CompletionError, CompletionRequest},
25 json_utils, message,
26};
27use serde::{Deserialize, Serialize};
28use tracing::{Instrument, enabled, info_span};
29
/// Base URL for the Galadriel API; exposed to the client machinery as
/// `ProviderBuilder::BASE_URL` for [`GaladrielBuilder`].
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
34
/// Provider-specific state carried by a built Galadriel client.
///
/// Produced from a [`GaladrielBuilder`] in `Provider::build`.
#[derive(Debug, Default, Clone)]
pub struct GaladrielExt {
    /// Optional fine-tune API key; `None` unless one was set on the builder.
    fine_tune_api_key: Option<String>,
}
39
/// Builder-side counterpart of [`GaladrielExt`], holding the provider-specific
/// options collected before the client is built.
#[derive(Debug, Default, Clone)]
pub struct GaladrielBuilder {
    /// Optional fine-tune API key; defaults to `None`.
    fine_tune_api_key: Option<String>,
}
44
/// API key type used by the Galadriel builder; an alias for [`BearerAuth`].
type GaladrielApiKey = BearerAuth;
46
47impl Provider for GaladrielExt {
48 type Builder = GaladrielBuilder;
49
50 const VERIFY_PATH: &'static str = "";
52
53 fn build<H>(
54 builder: &crate::client::ClientBuilder<
55 Self::Builder,
56 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
57 H,
58 >,
59 ) -> http_client::Result<Self> {
60 let GaladrielBuilder { fine_tune_api_key } = builder.ext().clone();
61
62 Ok(Self { fine_tune_api_key })
63 }
64}
65
/// Capability declaration for Galadriel: only completion is marked
/// `Capable` (via [`CompletionModel`]); every other capability is `Nothing`.
impl<H> Capabilities<H> for GaladrielExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    // The following associated types only exist when the matching crate
    // feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
75
76impl DebugExt for GaladrielExt {
77 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
78 std::iter::once((
79 "fine_tune_api_key",
80 (&self.fine_tune_api_key as &dyn std::fmt::Debug),
81 ))
82 }
83}
84
/// Wires [`GaladrielBuilder`] into the generic client builder: it builds a
/// [`GaladrielExt`], authenticates with [`GaladrielApiKey`], and uses
/// [`GALADRIEL_API_BASE_URL`] as its base URL.
impl ProviderBuilder for GaladrielBuilder {
    type Output = GaladrielExt;
    type ApiKey = GaladrielApiKey;

    const BASE_URL: &'static str = GALADRIEL_API_BASE_URL;
}
91
/// Galadriel client; the HTTP client type parameter defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<GaladrielExt, H>;
/// Builder for [`Client`]; the HTTP client type parameter defaults to `reqwest::Client`.
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<GaladrielBuilder, GaladrielApiKey, H>;
95
96impl<T> ClientBuilder<T> {
97 pub fn fine_tune_api_key<S>(mut self, fine_tune_api_key: S) -> Self
98 where
99 S: AsRef<str>,
100 {
101 *self.ext_mut() = GaladrielBuilder {
102 fine_tune_api_key: Some(fine_tune_api_key.as_ref().into()),
103 };
104
105 self
106 }
107}
108
109impl ProviderClient for Client {
110 type Input = (String, Option<String>);
111
112 fn from_env() -> Self {
116 let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
117 let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
118
119 let mut builder = Self::builder().api_key(api_key);
120
121 if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
122 builder = builder.fine_tune_api_key(fine_tune_api_key);
123 }
124
125 builder.build().unwrap()
126 }
127
128 fn from_val((api_key, fine_tune_api_key): Self::Input) -> Self {
129 let mut builder = Self::builder().api_key(api_key);
130
131 if let Some(fine_tune_key) = fine_tune_api_key {
132 builder = builder.fine_tune_api_key(fine_tune_key)
133 }
134
135 builder.build().unwrap()
136 }
137}
138
/// Error payload deserialized from a response body; only the `message`
/// field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text, surfaced to callers as `CompletionError::ProviderError`.
    message: String,
}
143
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `untagged`, so serde attempts the variants in declaration
/// order: `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
150
/// Token usage attached to a completion response.
///
/// Only prompt and total counts are carried; elsewhere in this file the
/// output token count is derived as `total_tokens - prompt_tokens`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of prompt (input) tokens.
    pub prompt_tokens: usize,
    /// Total number of tokens for the request.
    pub total_tokens: usize,
}
156
157impl std::fmt::Display for Usage {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 write!(
160 f,
161 "Prompt tokens: {} Total tokens: {}",
162 self.prompt_tokens, self.total_tokens
163 )
164 }
165}
166
/// Model identifier `o1-preview`.
pub const O1_PREVIEW: &str = "o1-preview";
/// Model identifier `o1-preview-2024-09-12`.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// Model identifier `o1-mini`.
pub const O1_MINI: &str = "o1-mini";
/// Model identifier `o1-mini-2024-09-12`.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// Model identifier `gpt-4o`.
pub const GPT_4O: &str = "gpt-4o";
/// Model identifier `gpt-4o-2024-05-13`.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// Model identifier `gpt-4-turbo`.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// Model identifier `gpt-4-turbo-2024-04-09`.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// Model identifier `gpt-4-turbo-preview`.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// Model identifier `gpt-4-0125-preview`.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// Model identifier `gpt-4-1106-preview`.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// Model identifier `gpt-4-vision-preview`.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// Model identifier `gpt-4-1106-vision-preview`.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// Model identifier `gpt-4`.
pub const GPT_4: &str = "gpt-4";
/// Model identifier `gpt-4-0613`.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// Model identifier `gpt-4-32k`.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// Model identifier `gpt-4-32k-0613`.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// Model identifier `gpt-3.5-turbo`.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// Model identifier `gpt-3.5-turbo-0125`.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// Model identifier `gpt-3.5-turbo-1106`.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// Model identifier `gpt-3.5-turbo-instruct`.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
213
/// Raw chat-completion response body as deserialized from the API.
///
/// Converted into the provider-agnostic `completion::CompletionResponse`
/// through the `TryFrom` implementation in this file.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Response identifier (recorded on the tracing span as `gen_ai.response.id`).
    pub id: String,
    pub object: String,
    pub created: u64,
    /// Model name reported by the API.
    pub model: String,
    pub system_fingerprint: Option<String>,
    /// Returned choices; only the first one is used during conversion.
    pub choices: Vec<Choice>,
    /// Token usage, when the API reports it.
    pub usage: Option<Usage>,
}
224
225impl From<ApiErrorResponse> for CompletionError {
226 fn from(err: ApiErrorResponse) -> Self {
227 CompletionError::ProviderError(err.message)
228 }
229}
230
231impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
232 type Error = CompletionError;
233
234 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
235 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
236 CompletionError::ResponseError("Response contained no choices".to_owned())
237 })?;
238
239 let mut content = message
240 .content
241 .as_ref()
242 .map(|c| vec![completion::AssistantContent::text(c)])
243 .unwrap_or_default();
244
245 content.extend(message.tool_calls.iter().map(|call| {
246 completion::AssistantContent::tool_call(
247 &call.function.name,
248 &call.function.name,
249 call.function.arguments.clone(),
250 )
251 }));
252
253 let choice = OneOrMany::many(content).map_err(|_| {
254 CompletionError::ResponseError(
255 "Response contained no message or tool call (empty)".to_owned(),
256 )
257 })?;
258 let usage = response
259 .usage
260 .as_ref()
261 .map(|usage| completion::Usage {
262 input_tokens: usage.prompt_tokens as u64,
263 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
264 total_tokens: usage.total_tokens as u64,
265 cached_input_tokens: 0,
266 })
267 .unwrap_or_default();
268
269 Ok(completion::CompletionResponse {
270 choice,
271 usage,
272 raw_response: response,
273 })
274 }
275}
276
/// A single choice entry of a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    /// The message produced for this choice.
    pub message: Message,
    /// Kept as untyped JSON; not interpreted in this file.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
284
/// Chat message in the wire format used for both requests and responses.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Role string; this file produces/accepts `"system"`, `"user"` and `"assistant"`.
    pub role: String,
    /// Text content, if any.
    pub content: Option<String>,
    /// Tool calls attached to the message. A missing field deserializes to an
    /// empty vector (`default`).
    // NOTE(review): `json_utils::null_or_vec` presumably also maps an explicit
    // JSON `null` to an empty vector — helper not visible here, confirm.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<openai::ToolCall>,
}
292
293impl Message {
294 fn system(preamble: &str) -> Self {
295 Self {
296 role: "system".to_string(),
297 content: Some(preamble.to_string()),
298 tool_calls: Vec::new(),
299 }
300 }
301}
302
303impl TryFrom<Message> for message::Message {
304 type Error = message::MessageError;
305
306 fn try_from(message: Message) -> Result<Self, Self::Error> {
307 let tool_calls: Vec<message::ToolCall> = message
308 .tool_calls
309 .into_iter()
310 .map(|tool_call| tool_call.into())
311 .collect();
312
313 match message.role.as_str() {
314 "user" => Ok(Self::User {
315 content: OneOrMany::one(
316 message
317 .content
318 .map(|content| message::UserContent::text(&content))
319 .ok_or_else(|| {
320 message::MessageError::ConversionError("Empty user message".to_string())
321 })?,
322 ),
323 }),
324 "assistant" => Ok(Self::Assistant {
325 id: None,
326 content: OneOrMany::many(
327 tool_calls
328 .into_iter()
329 .map(message::AssistantContent::ToolCall)
330 .chain(
331 message
332 .content
333 .map(|content| message::AssistantContent::text(&content))
334 .into_iter(),
335 ),
336 )
337 .map_err(|_| {
338 message::MessageError::ConversionError("Empty assistant message".to_string())
339 })?,
340 }),
341 _ => Err(message::MessageError::ConversionError(format!(
342 "Unknown role: {}",
343 message.role
344 ))),
345 }
346 }
347}
348
349impl TryFrom<message::Message> for Message {
350 type Error = message::MessageError;
351
352 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
353 match message {
354 message::Message::User { content } => Ok(Self {
355 role: "user".to_string(),
356 content: content.iter().find_map(|c| match c {
357 message::UserContent::Text(text) => Some(text.text.clone()),
358 _ => None,
359 }),
360 tool_calls: vec![],
361 }),
362 message::Message::Assistant { content, .. } => {
363 let mut text_content: Option<String> = None;
364 let mut tool_calls = vec![];
365
366 for c in content.iter() {
367 match c {
368 message::AssistantContent::Text(text) => {
369 text_content = Some(
370 text_content
371 .map(|mut existing| {
372 existing.push('\n');
373 existing.push_str(&text.text);
374 existing
375 })
376 .unwrap_or_else(|| text.text.clone()),
377 );
378 }
379 message::AssistantContent::ToolCall(tool_call) => {
380 tool_calls.push(tool_call.clone().into());
381 }
382 message::AssistantContent::Reasoning(_) => {
383 return Err(MessageError::ConversionError(
384 "Galadriel currently doesn't support reasoning.".into(),
385 ));
386 }
387 message::AssistantContent::Image(_) => {
388 return Err(MessageError::ConversionError(
389 "Galadriel currently doesn't support images.".into(),
390 ));
391 }
392 }
393 }
394
395 Ok(Self {
396 role: "assistant".to_string(),
397 content: text_content,
398 tool_calls,
399 })
400 }
401 }
402 }
403}
404
/// Tool definition in the wire format: a `type` tag wrapping the crate-level
/// `completion::ToolDefinition`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind; the `From` conversion in this file always sets `"function"`.
    pub r#type: String,
    /// The wrapped tool definition.
    pub function: completion::ToolDefinition,
}
410
411impl From<completion::ToolDefinition> for ToolDefinition {
412 fn from(tool: completion::ToolDefinition) -> Self {
413 Self {
414 r#type: "function".into(),
415 function: tool,
416 }
417 }
418}
419
/// A function name together with its arguments as a string.
// NOTE(review): not referenced anywhere in the visible part of this file;
// tool calls here use `openai::ToolCall` instead — confirm whether it is still needed.
#[derive(Debug, Deserialize)]
pub struct Function {
    pub name: String,
    pub arguments: String,
}
425
/// Request body sent to `/chat/completions`.
///
/// Built from a model name and a `CompletionRequest` via the `TryFrom`
/// implementation in this file.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GaladrielCompletionRequest {
    /// Name of the model to run.
    model: String,
    /// Full message history: optional system preamble, then documents and chat history.
    pub messages: Vec<Message>,
    /// Sampling temperature; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tool definitions; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Tool choice; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    /// Extra parameters, flattened into the top-level JSON object when present.
    /// The streaming path merges its `stream` options in here.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
439
440impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest {
441 type Error = CompletionError;
442
443 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
444 let mut partial_history = vec![];
446 if let Some(docs) = req.normalized_documents() {
447 partial_history.push(docs);
448 }
449 partial_history.extend(req.chat_history);
450
451 let mut full_history: Vec<Message> = match &req.preamble {
453 Some(preamble) => vec![Message::system(preamble)],
454 None => vec![],
455 };
456
457 full_history.extend(
459 partial_history
460 .into_iter()
461 .map(message::Message::try_into)
462 .collect::<Result<Vec<Message>, _>>()?,
463 );
464
465 let tool_choice = req
466 .tool_choice
467 .clone()
468 .map(crate::providers::openai::completion::ToolChoice::try_from)
469 .transpose()?;
470
471 Ok(Self {
472 model: model.to_string(),
473 messages: full_history,
474 temperature: req.temperature,
475 tools: req
476 .tools
477 .clone()
478 .into_iter()
479 .map(ToolDefinition::from)
480 .collect::<Vec<_>>(),
481 tool_choice,
482 additional_params: req.additional_params,
483 })
484 }
485}
486
/// Completion model handle: a Galadriel [`Client`] paired with a model name.
///
/// The HTTP client type parameter defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    client: Client<T>,
    /// Name of the model sent in each request.
    pub model: String,
}
493
494impl<T> CompletionModel<T>
495where
496 T: HttpClientExt,
497{
498 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
499 Self {
500 client,
501 model: model.into(),
502 }
503 }
504
505 pub fn with_model(client: Client<T>, model: &str) -> Self {
506 Self {
507 client,
508 model: model.into(),
509 }
510 }
511}
512
513impl<T> completion::CompletionModel for CompletionModel<T>
514where
515 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
516{
517 type Response = CompletionResponse;
518 type StreamingResponse = openai::StreamingCompletionResponse;
519
520 type Client = Client<T>;
521
522 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
523 Self::new(client.clone(), model.into())
524 }
525
526 async fn completion(
527 &self,
528 completion_request: CompletionRequest,
529 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
530 let span = if tracing::Span::current().is_disabled() {
531 info_span!(
532 target: "rig::completions",
533 "chat",
534 gen_ai.operation.name = "chat",
535 gen_ai.provider.name = "galadriel",
536 gen_ai.request.model = self.model,
537 gen_ai.system_instructions = tracing::field::Empty,
538 gen_ai.response.id = tracing::field::Empty,
539 gen_ai.response.model = tracing::field::Empty,
540 gen_ai.usage.output_tokens = tracing::field::Empty,
541 gen_ai.usage.input_tokens = tracing::field::Empty,
542 )
543 } else {
544 tracing::Span::current()
545 };
546
547 span.record("gen_ai.system_instructions", &completion_request.preamble);
548
549 let request =
550 GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
551
552 if enabled!(tracing::Level::TRACE) {
553 tracing::trace!(target: "rig::completions",
554 "Galadriel completion request: {}",
555 serde_json::to_string_pretty(&request)?
556 );
557 }
558
559 let body = serde_json::to_vec(&request)?;
560
561 let req = self
562 .client
563 .post("/chat/completions")?
564 .body(body)
565 .map_err(http_client::Error::from)?;
566
567 async move {
568 let response = self.client.send(req).await?;
569
570 if response.status().is_success() {
571 let t = http_client::text(response).await?;
572
573 if enabled!(tracing::Level::TRACE) {
574 tracing::trace!(target: "rig::completions",
575 "Galadriel completion response: {}",
576 serde_json::to_string_pretty(&t)?
577 );
578 }
579
580 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
581 ApiResponse::Ok(response) => {
582 let span = tracing::Span::current();
583 span.record("gen_ai.response.id", response.id.clone());
584 span.record("gen_ai.response.model_name", response.model.clone());
585 if let Some(ref usage) = response.usage {
586 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
587 span.record(
588 "gen_ai.usage.output_tokens",
589 usage.total_tokens - usage.prompt_tokens,
590 );
591 }
592 response.try_into()
593 }
594 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
595 }
596 } else {
597 let text = http_client::text(response).await?;
598
599 Err(CompletionError::ProviderError(text))
600 }
601 }
602 .instrument(span)
603 .await
604 }
605
606 async fn stream(
607 &self,
608 completion_request: CompletionRequest,
609 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
610 let preamble = completion_request.preamble.clone();
611 let mut request =
612 GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
613
614 let params = json_utils::merge(
615 request.additional_params.unwrap_or(serde_json::json!({})),
616 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
617 );
618
619 request.additional_params = Some(params);
620
621 let body = serde_json::to_vec(&request)?;
622
623 let req = self
624 .client
625 .post("/chat/completions")?
626 .body(body)
627 .map_err(http_client::Error::from)?;
628
629 let span = if tracing::Span::current().is_disabled() {
630 info_span!(
631 target: "rig::completions",
632 "chat_streaming",
633 gen_ai.operation.name = "chat_streaming",
634 gen_ai.provider.name = "galadriel",
635 gen_ai.request.model = self.model,
636 gen_ai.system_instructions = preamble,
637 gen_ai.response.id = tracing::field::Empty,
638 gen_ai.response.model = tracing::field::Empty,
639 gen_ai.usage.output_tokens = tracing::field::Empty,
640 gen_ai.usage.input_tokens = tracing::field::Empty,
641 gen_ai.input.messages = serde_json::to_string(&request.messages)?,
642 gen_ai.output.messages = tracing::field::Empty,
643 )
644 } else {
645 tracing::Span::current()
646 };
647
648 send_compatible_streaming_request(self.client.clone(), req)
649 .instrument(span)
650 .await
651 }
652}