1use super::openai;
14use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
15use crate::http_client::{self, HttpClientExt};
16use crate::json_utils::merge;
17use crate::message::MessageError;
18use crate::providers::openai::send_compatible_streaming_request;
19use crate::streaming::StreamingCompletionResponse;
20use crate::{
21 OneOrMany,
22 completion::{self, CompletionError, CompletionRequest},
23 impl_conversion_traits, json_utils, message,
24};
25use serde::{Deserialize, Serialize};
26use serde_json::{Value, json};
27use tracing::{Instrument, info_span};
28
/// Default base URL for the Galadriel verified-inference API.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
33
/// Builder for [`Client`], collecting connection settings before construction.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    /// Bearer token sent with every request.
    api_key: &'a str,
    /// Optional key sent as the `Fine-Tune-Authorization` header.
    fine_tune_api_key: Option<&'a str>,
    /// API base URL; defaults to [`GALADRIEL_API_BASE_URL`].
    base_url: &'a str,
    /// HTTP transport implementation.
    http_client: T,
}
40
41impl<'a, T> ClientBuilder<'a, T>
42where
43 T: Default,
44{
45 pub fn new(api_key: &'a str) -> Self {
46 Self {
47 api_key,
48 fine_tune_api_key: None,
49 base_url: GALADRIEL_API_BASE_URL,
50 http_client: Default::default(),
51 }
52 }
53}
54
55impl<'a, T> ClientBuilder<'a, T> {
56 pub fn fine_tune_api_key(mut self, fine_tune_api_key: &'a str) -> Self {
57 self.fine_tune_api_key = Some(fine_tune_api_key);
58 self
59 }
60
61 pub fn base_url(mut self, base_url: &'a str) -> Self {
62 self.base_url = base_url;
63 self
64 }
65
66 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
67 ClientBuilder {
68 api_key: self.api_key,
69 fine_tune_api_key: self.fine_tune_api_key,
70 base_url: self.base_url,
71 http_client,
72 }
73 }
74
75 pub fn build(self) -> Client<T> {
76 Client {
77 base_url: self.base_url.to_string(),
78 api_key: self.api_key.to_string(),
79 fine_tune_api_key: self.fine_tune_api_key.map(|x| x.to_string()),
80 http_client: self.http_client,
81 }
82 }
83}
/// Galadriel API client over a pluggable HTTP transport `T`.
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    /// API base URL (no trailing slash expected).
    base_url: String,
    /// Bearer token for the main API.
    api_key: String,
    /// Optional key sent as the `Fine-Tune-Authorization` header.
    fine_tune_api_key: Option<String>,
    /// HTTP transport implementation.
    http_client: T,
}
91
impl<T> std::fmt::Debug for Client<T>
where
    T: std::fmt::Debug,
{
    /// Manual `Debug` so both API keys are redacted from diagnostic output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("http_client", &self.http_client)
            .field("api_key", &"<REDACTED>")
            .field("fine_tune_api_key", &"<REDACTED>")
            .finish()
    }
}
105
106impl<T> Client<T>
107where
108 T: HttpClientExt,
109{
110 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
111 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
112
113 let mut req = http_client::Request::post(url);
114
115 if let Some(fine_tune_key) = self.fine_tune_api_key.clone() {
116 req = req.header("Fine-Tune-Authorization", fine_tune_key);
117 }
118
119 http_client::with_bearer_auth(req, &self.api_key)
120 }
121}
122
123impl Client<reqwest::Client> {
124 pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
125 ClientBuilder::new(api_key)
126 }
127
128 pub fn new(api_key: &str) -> Self {
129 ClientBuilder::new(api_key).build()
130 }
131
132 pub fn from_env() -> Self {
133 <Self as ProviderClient>::from_env()
134 }
135}
136
137impl<T> ProviderClient for Client<T>
138where
139 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
140{
141 fn from_env() -> Self {
145 let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
146 let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
147 let mut builder = ClientBuilder::<T>::new(&api_key);
148 if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
149 builder = builder.fine_tune_api_key(fine_tune_api_key);
150 }
151 builder.build()
152 }
153
154 fn from_val(input: crate::client::ProviderValue) -> Self {
155 let crate::client::ProviderValue::ApiKeyWithOptionalKey(api_key, fine_tune_key) = input
156 else {
157 panic!("Incorrect provider value type")
158 };
159 let mut builder = ClientBuilder::<T>::new(&api_key);
160 if let Some(fine_tune_key) = fine_tune_key.as_deref() {
161 builder = builder.fine_tune_api_key(fine_tune_key);
162 }
163 builder.build()
164 }
165}
166
impl<T> CompletionClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    type CompletionModel = CompletionModel<T>;

    /// Creates a completion-model handle bound to this client for `model`.
    fn completion_model(&self, model: &str) -> Self::CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
188
impl<T> VerifyClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    // NOTE(review): this is a no-op — it always reports success without
    // contacting the API, so an invalid key is not detected here. Confirm
    // whether Galadriel exposes an endpoint suitable for real verification.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        Ok(())
    }
}
199
// Galadriel exposes no embeddings/transcription/image/audio endpoints here,
// so generate the "unsupported capability" conversion impls for those traits.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
206
/// Error payload returned by the Galadriel API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Human-readable error description from the provider.
    message: String,
}
211
/// Untagged wrapper: a successful payload `T` or a provider error body.
/// Serde tries `Ok(T)` first and falls back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
218
/// Token usage as reported by the Galadriel API.
/// Output tokens are derived elsewhere as `total_tokens - prompt_tokens`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: usize,
    /// Total tokens for the request (prompt + completion).
    pub total_tokens: usize,
}
224
225impl std::fmt::Display for Usage {
226 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227 write!(
228 f,
229 "Prompt tokens: {} Total tokens: {}",
230 self.prompt_tokens, self.total_tokens
231 )
232 }
233}
234
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
280
/// Raw chat-completion response body from the Galadriel API
/// (OpenAI-compatible shape).
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Provider-assigned response id.
    pub id: String,
    /// Object type tag from the API.
    pub object: String,
    /// Creation time (Unix seconds).
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    /// Backend fingerprint, when provided.
    pub system_fingerprint: Option<String>,
    /// Candidate completions; conversion uses the first entry.
    pub choices: Vec<Choice>,
    /// Token accounting, when provided.
    pub usage: Option<Usage>,
}
291
// Surface provider error bodies as `CompletionError::ProviderError`.
impl From<ApiErrorResponse> for CompletionError {
    fn from(err: ApiErrorResponse) -> Self {
        CompletionError::ProviderError(err.message)
    }
}
297
298impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
299 type Error = CompletionError;
300
301 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
302 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
303 CompletionError::ResponseError("Response contained no choices".to_owned())
304 })?;
305
306 let mut content = message
307 .content
308 .as_ref()
309 .map(|c| vec![completion::AssistantContent::text(c)])
310 .unwrap_or_default();
311
312 content.extend(message.tool_calls.iter().map(|call| {
313 completion::AssistantContent::tool_call(
314 &call.function.name,
315 &call.function.name,
316 call.function.arguments.clone(),
317 )
318 }));
319
320 let choice = OneOrMany::many(content).map_err(|_| {
321 CompletionError::ResponseError(
322 "Response contained no message or tool call (empty)".to_owned(),
323 )
324 })?;
325 let usage = response
326 .usage
327 .as_ref()
328 .map(|usage| completion::Usage {
329 input_tokens: usage.prompt_tokens as u64,
330 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
331 total_tokens: usage.total_tokens as u64,
332 })
333 .unwrap_or_default();
334
335 Ok(completion::CompletionResponse {
336 choice,
337 usage,
338 raw_response: response,
339 })
340 }
341}
342
/// A single completion candidate within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Position of this choice in the response.
    pub index: usize,
    /// The generated message.
    pub message: Message,
    /// Log-probability data, when requested.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped (e.g. "stop", "tool_calls").
    pub finish_reason: String,
}
350
/// Wire-format chat message ("system", "user", or "assistant" role).
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Role string as used by the API.
    pub role: String,
    /// Text content; may be absent when only tool calls are present.
    pub content: Option<String>,
    // `null_or_vec` maps a JSON `null` to an empty vector.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<openai::ToolCall>,
}
358
359impl TryFrom<Message> for message::Message {
360 type Error = message::MessageError;
361
362 fn try_from(message: Message) -> Result<Self, Self::Error> {
363 let tool_calls: Vec<message::ToolCall> = message
364 .tool_calls
365 .into_iter()
366 .map(|tool_call| tool_call.into())
367 .collect();
368
369 match message.role.as_str() {
370 "user" => Ok(Self::User {
371 content: OneOrMany::one(
372 message
373 .content
374 .map(|content| message::UserContent::text(&content))
375 .ok_or_else(|| {
376 message::MessageError::ConversionError("Empty user message".to_string())
377 })?,
378 ),
379 }),
380 "assistant" => Ok(Self::Assistant {
381 id: None,
382 content: OneOrMany::many(
383 tool_calls
384 .into_iter()
385 .map(message::AssistantContent::ToolCall)
386 .chain(
387 message
388 .content
389 .map(|content| message::AssistantContent::text(&content))
390 .into_iter(),
391 ),
392 )
393 .map_err(|_| {
394 message::MessageError::ConversionError("Empty assistant message".to_string())
395 })?,
396 }),
397 _ => Err(message::MessageError::ConversionError(format!(
398 "Unknown role: {}",
399 message.role
400 ))),
401 }
402 }
403}
404
405impl TryFrom<message::Message> for Message {
406 type Error = message::MessageError;
407
408 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
409 match message {
410 message::Message::User { content } => Ok(Self {
411 role: "user".to_string(),
412 content: content.iter().find_map(|c| match c {
413 message::UserContent::Text(text) => Some(text.text.clone()),
414 _ => None,
415 }),
416 tool_calls: vec![],
417 }),
418 message::Message::Assistant { content, .. } => {
419 let mut text_content: Option<String> = None;
420 let mut tool_calls = vec![];
421
422 for c in content.iter() {
423 match c {
424 message::AssistantContent::Text(text) => {
425 text_content = Some(
426 text_content
427 .map(|mut existing| {
428 existing.push('\n');
429 existing.push_str(&text.text);
430 existing
431 })
432 .unwrap_or_else(|| text.text.clone()),
433 );
434 }
435 message::AssistantContent::ToolCall(tool_call) => {
436 tool_calls.push(tool_call.clone().into());
437 }
438 message::AssistantContent::Reasoning(_) => {
439 return Err(MessageError::ConversionError(
440 "Galadriel currently doesn't support reasoning.".into(),
441 ));
442 }
443 }
444 }
445
446 Ok(Self {
447 role: "assistant".to_string(),
448 content: text_content,
449 tool_calls,
450 })
451 }
452 }
453 }
454}
455
/// Wire-format tool definition wrapping the internal definition with a
/// `type` tag (always "function" in practice — see the `From` impl below).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
461
// Wrap an internal tool definition in the provider's `{"type": "function"}`
// envelope.
impl From<completion::ToolDefinition> for ToolDefinition {
    fn from(tool: completion::ToolDefinition) -> Self {
        Self {
            r#type: "function".into(),
            function: tool,
        }
    }
}
470
/// Function-call payload shape: a function name plus its JSON-encoded
/// arguments string.
// NOTE(review): not referenced elsewhere in this file — presumably part of the
// module's public API used by callers; verify before removing.
#[derive(Debug, Deserialize)]
pub struct Function {
    pub name: String,
    pub arguments: String,
}
476
/// A Galadriel completion model: a [`Client`] paired with a model name.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to issue requests.
    client: Client<T>,
    /// Model identifier sent in the request body (e.g. [`GPT_4O`]).
    pub model: String,
}
483
484impl<T> CompletionModel<T>
485where
486 T: HttpClientExt,
487{
488 pub fn new(client: Client<T>, model: &str) -> Self {
489 Self {
490 client,
491 model: model.to_string(),
492 }
493 }
494
495 pub(crate) fn create_completion_request(
496 &self,
497 completion_request: CompletionRequest,
498 ) -> Result<Value, CompletionError> {
499 let mut partial_history = vec![];
501 if let Some(docs) = completion_request.normalized_documents() {
502 partial_history.push(docs);
503 }
504 partial_history.extend(completion_request.chat_history);
505
506 let mut full_history: Vec<Message> = match &completion_request.preamble {
508 Some(preamble) => vec![Message {
509 role: "system".to_string(),
510 content: Some(preamble.to_string()),
511 tool_calls: vec![],
512 }],
513 None => vec![],
514 };
515
516 full_history.extend(
518 partial_history
519 .into_iter()
520 .map(message::Message::try_into)
521 .collect::<Result<Vec<Message>, _>>()?,
522 );
523
524 let tool_choice = completion_request
525 .tool_choice
526 .clone()
527 .map(crate::providers::openai::completion::ToolChoice::try_from)
528 .transpose()?;
529
530 let request = if completion_request.tools.is_empty() {
531 json!({
532 "model": self.model,
533 "messages": full_history,
534 "temperature": completion_request.temperature,
535 })
536 } else {
537 json!({
538 "model": self.model,
539 "messages": full_history,
540 "temperature": completion_request.temperature,
541 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
542 "tool_choice": tool_choice,
543 })
544 };
545
546 let request = if let Some(params) = completion_request.additional_params {
547 json_utils::merge(request, params)
548 } else {
549 request
550 };
551
552 Ok(request)
553 }
554}
555
556impl<T> completion::CompletionModel for CompletionModel<T>
557where
558 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
559{
560 type Response = CompletionResponse;
561 type StreamingResponse = openai::StreamingCompletionResponse;
562
563 #[cfg_attr(feature = "worker", worker::send)]
564 async fn completion(
565 &self,
566 completion_request: CompletionRequest,
567 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
568 let preamble = completion_request.preamble.clone();
569 let request = self.create_completion_request(completion_request)?;
570 let body = serde_json::to_vec(&request)?;
571
572 let req = self
573 .client
574 .post("/chat/completions")?
575 .header("Content-Type", "application/json")
576 .body(body)
577 .map_err(http_client::Error::from)?;
578
579 let span = if tracing::Span::current().is_disabled() {
580 info_span!(
581 target: "rig::completions",
582 "chat",
583 gen_ai.operation.name = "chat",
584 gen_ai.provider.name = "galadriel",
585 gen_ai.request.model = self.model,
586 gen_ai.system_instructions = preamble,
587 gen_ai.response.id = tracing::field::Empty,
588 gen_ai.response.model = tracing::field::Empty,
589 gen_ai.usage.output_tokens = tracing::field::Empty,
590 gen_ai.usage.input_tokens = tracing::field::Empty,
591 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
592 gen_ai.output.messages = tracing::field::Empty,
593 )
594 } else {
595 tracing::Span::current()
596 };
597
598 async move {
599 let response = self.client.http_client.send(req).await?;
600
601 if response.status().is_success() {
602 let t = http_client::text(response).await?;
603 tracing::debug!(target: "rig::completions", "Galadriel completion response: {t}");
604
605 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
606 ApiResponse::Ok(response) => {
607 let span = tracing::Span::current();
608 span.record("gen_ai.response.id", response.id.clone());
609 span.record("gen_ai.response.model_name", response.model.clone());
610 span.record(
611 "gen_ai.output.messages",
612 serde_json::to_string(&response.choices).unwrap(),
613 );
614 if let Some(ref usage) = response.usage {
615 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
616 span.record(
617 "gen_ai.usage.output_tokens",
618 usage.total_tokens - usage.prompt_tokens,
619 );
620 }
621 response.try_into()
622 }
623 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
624 }
625 } else {
626 let text = http_client::text(response).await?;
627
628 Err(CompletionError::ProviderError(text))
629 }
630 }
631 .instrument(span)
632 .await
633 }
634
635 #[cfg_attr(feature = "worker", worker::send)]
636 async fn stream(
637 &self,
638 request: CompletionRequest,
639 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
640 let preamble = request.preamble.clone();
641 let mut request = self.create_completion_request(request)?;
642
643 request = merge(
644 request,
645 json!({"stream": true, "stream_options": {"include_usage": true}}),
646 );
647
648 let body = serde_json::to_vec(&request)?;
649
650 let req = self
651 .client
652 .post("/chat/completions")?
653 .header("Content-Type", "application/json")
654 .body(body)
655 .map_err(http_client::Error::from)?;
656
657 let span = if tracing::Span::current().is_disabled() {
658 info_span!(
659 target: "rig::completions",
660 "chat_streaming",
661 gen_ai.operation.name = "chat_streaming",
662 gen_ai.provider.name = "galadriel",
663 gen_ai.request.model = self.model,
664 gen_ai.system_instructions = preamble,
665 gen_ai.response.id = tracing::field::Empty,
666 gen_ai.response.model = tracing::field::Empty,
667 gen_ai.usage.output_tokens = tracing::field::Empty,
668 gen_ai.usage.input_tokens = tracing::field::Empty,
669 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
670 gen_ai.output.messages = tracing::field::Empty,
671 )
672 } else {
673 tracing::Span::current()
674 };
675
676 send_compatible_streaming_request(self.client.http_client.clone(), req)
677 .instrument(span)
678 .await
679 }
680}