1use super::openai;
14use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
15use crate::http_client::{self, HttpClientExt};
16use crate::json_utils::merge;
17use crate::message::MessageError;
18use crate::providers::openai::send_compatible_streaming_request;
19use crate::streaming::StreamingCompletionResponse;
20use crate::{
21 OneOrMany,
22 completion::{self, CompletionError, CompletionRequest},
23 impl_conversion_traits, json_utils, message,
24};
25use bytes::Bytes;
26use serde::{Deserialize, Serialize};
27use serde_json::{Value, json};
28use tracing::{Instrument, info_span};
29
/// Default base URL for Galadriel API requests; builders start from this value
/// unless `base_url` is overridden.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
34
/// Builder for a Galadriel [`Client`].
///
/// String settings are borrowed for `'a`; `T` is the HTTP client type carried
/// into the built client.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    /// API key; the built client sends it as a bearer token.
    api_key: &'a str,
    /// Optional key; when set, the built client sends it in the
    /// `Fine-Tune-Authorization` header.
    fine_tune_api_key: Option<&'a str>,
    /// Base URL that request paths are appended to.
    base_url: &'a str,
    /// HTTP client instance handed to the built client.
    http_client: T,
}
41
42impl<'a, T> ClientBuilder<'a, T>
43where
44 T: Default,
45{
46 pub fn new(api_key: &'a str) -> Self {
47 Self {
48 api_key,
49 fine_tune_api_key: None,
50 base_url: GALADRIEL_API_BASE_URL,
51 http_client: Default::default(),
52 }
53 }
54}
55
56impl<'a, T> ClientBuilder<'a, T> {
57 pub fn fine_tune_api_key(mut self, fine_tune_api_key: &'a str) -> Self {
58 self.fine_tune_api_key = Some(fine_tune_api_key);
59 self
60 }
61
62 pub fn base_url(mut self, base_url: &'a str) -> Self {
63 self.base_url = base_url;
64 self
65 }
66
67 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
68 ClientBuilder {
69 api_key: self.api_key,
70 fine_tune_api_key: self.fine_tune_api_key,
71 base_url: self.base_url,
72 http_client,
73 }
74 }
75
76 pub fn build(self) -> Client<T> {
77 Client {
78 base_url: self.base_url.to_string(),
79 api_key: self.api_key.to_string(),
80 fine_tune_api_key: self.fine_tune_api_key.map(|x| x.to_string()),
81 http_client: self.http_client,
82 }
83 }
84}
/// Galadriel API client, generic over the HTTP client type `T`.
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// API key sent as a bearer token.
    api_key: String,
    /// Optional key sent in the `Fine-Tune-Authorization` header.
    fine_tune_api_key: Option<String>,
    /// Underlying HTTP client used to send requests.
    http_client: T,
}
92
93impl<T> std::fmt::Debug for Client<T>
94where
95 T: std::fmt::Debug,
96{
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("Client")
99 .field("base_url", &self.base_url)
100 .field("http_client", &self.http_client)
101 .field("api_key", &"<REDACTED>")
102 .field("fine_tune_api_key", &"<REDACTED>")
103 .finish()
104 }
105}
106
107impl<T> Client<T>
108where
109 T: Default,
110{
111 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
122 ClientBuilder::new(api_key)
123 }
124
125 pub fn new(api_key: &str) -> Self {
130 Self::builder(api_key).build()
131 }
132}
133
134impl<T> Client<T>
135where
136 T: HttpClientExt,
137{
138 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
139 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
140
141 let mut req = http_client::Request::post(url);
142
143 if let Some(fine_tune_key) = self.fine_tune_api_key.clone() {
144 req = req.header("Fine-Tune-Authorization", fine_tune_key);
145 }
146
147 http_client::with_bearer_auth(req, &self.api_key)
148 }
149
150 async fn send<U, R>(
151 &self,
152 req: http_client::Request<U>,
153 ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
154 where
155 U: Into<Bytes> + Send,
156 R: From<Bytes> + Send + 'static,
157 {
158 self.http_client.send(req).await
159 }
160}
161
162impl Client<reqwest::Client> {
163 fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
164 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
165 let mut req = self.http_client.post(url).bearer_auth(&self.api_key);
166
167 if let Some(fine_tune_key) = self.fine_tune_api_key.clone() {
168 req = req.header("Fine-Tune-Authorization", fine_tune_key)
169 }
170
171 req
172 }
173}
174
175impl ProviderClient for Client<reqwest::Client> {
176 fn from_env() -> Self {
180 let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
181 let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
182 let mut builder = Self::builder(&api_key);
183 if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
184 builder = builder.fine_tune_api_key(fine_tune_api_key);
185 }
186 builder.build()
187 }
188
189 fn from_val(input: crate::client::ProviderValue) -> Self {
190 let crate::client::ProviderValue::ApiKeyWithOptionalKey(api_key, fine_tune_key) = input
191 else {
192 panic!("Incorrect provider value type")
193 };
194 let mut builder = Self::builder(&api_key);
195 if let Some(fine_tune_key) = fine_tune_key.as_deref() {
196 builder = builder.fine_tune_api_key(fine_tune_key);
197 }
198 builder.build()
199 }
200}
201
202impl CompletionClient for Client<reqwest::Client> {
203 type CompletionModel = CompletionModel<reqwest::Client>;
204
205 fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
217 CompletionModel::new(self.clone(), model)
218 }
219}
220
impl VerifyClient for Client<reqwest::Client> {
    /// Always reports success; no request is made.
    ///
    /// NOTE(review): because nothing is sent to the API here, an invalid key is
    /// not detected by this method.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        Ok(())
    }
}
228
// NOTE(review): presumably generates the listed `As*` conversion trait impls for
// `Client<T>` (capabilities other than completion) — confirm against the macro
// definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
235
/// Error payload returned by the API; only the `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text, surfaced to callers as a provider error.
    message: String,
}
240
/// Either a successful payload `T` or an API error body.
///
/// The enum is untagged, so serde tries `Ok(T)` first and falls back to `Err`
/// when the body does not deserialize as `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error message.
    Err(ApiErrorResponse),
}
247
/// Token usage attached to a completion response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Tokens reported for the prompt.
    pub prompt_tokens: usize,
    /// Total tokens reported; this file derives output tokens as
    /// `total_tokens - prompt_tokens`.
    pub total_tokens: usize,
}
253
254impl std::fmt::Display for Usage {
255 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256 write!(
257 f,
258 "Prompt tokens: {} Total tokens: {}",
259 self.prompt_tokens, self.total_tokens
260 )
261 }
262}
263
// Model identifier strings that can be passed as the `model` name when creating
// a completion model.

/// Model identifier `o1-preview`.
pub const O1_PREVIEW: &str = "o1-preview";
/// Model identifier `o1-preview-2024-09-12`.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// Model identifier `o1-mini`.
pub const O1_MINI: &str = "o1-mini";
/// Model identifier `o1-mini-2024-09-12`.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// Model identifier `gpt-4o`.
pub const GPT_4O: &str = "gpt-4o";
/// Model identifier `gpt-4o-2024-05-13`.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// Model identifier `gpt-4-turbo`.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// Model identifier `gpt-4-turbo-2024-04-09`.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// Model identifier `gpt-4-turbo-preview`.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// Model identifier `gpt-4-0125-preview`.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// Model identifier `gpt-4-1106-preview`.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// Model identifier `gpt-4-vision-preview`.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// Model identifier `gpt-4-1106-vision-preview`.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// Model identifier `gpt-4`.
pub const GPT_4: &str = "gpt-4";
/// Model identifier `gpt-4-0613`.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// Model identifier `gpt-4-32k`.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// Model identifier `gpt-4-32k-0613`.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// Model identifier `gpt-3.5-turbo`.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// Model identifier `gpt-3.5-turbo-0125`.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// Model identifier `gpt-3.5-turbo-1106`.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// Model identifier `gpt-3.5-turbo-instruct`.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
309
/// Raw body of a successful chat completion response.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Response identifier; recorded on the tracing span as `gen_ai.response.id`.
    pub id: String,
    /// Object type string as returned by the API.
    pub object: String,
    /// Creation time as returned by the API (presumably a Unix timestamp — confirm).
    pub created: u64,
    /// Name of the model that produced the response.
    pub model: String,
    /// Optional fingerprint string as returned by the API.
    pub system_fingerprint: Option<String>,
    /// Returned choices; only the first one is used when converting to the
    /// provider-agnostic response.
    pub choices: Vec<Choice>,
    /// Token usage, when the API reports it.
    pub usage: Option<Usage>,
}
320
321impl From<ApiErrorResponse> for CompletionError {
322 fn from(err: ApiErrorResponse) -> Self {
323 CompletionError::ProviderError(err.message)
324 }
325}
326
327impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
328 type Error = CompletionError;
329
330 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
331 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
332 CompletionError::ResponseError("Response contained no choices".to_owned())
333 })?;
334
335 let mut content = message
336 .content
337 .as_ref()
338 .map(|c| vec![completion::AssistantContent::text(c)])
339 .unwrap_or_default();
340
341 content.extend(message.tool_calls.iter().map(|call| {
342 completion::AssistantContent::tool_call(
343 &call.function.name,
344 &call.function.name,
345 call.function.arguments.clone(),
346 )
347 }));
348
349 let choice = OneOrMany::many(content).map_err(|_| {
350 CompletionError::ResponseError(
351 "Response contained no message or tool call (empty)".to_owned(),
352 )
353 })?;
354 let usage = response
355 .usage
356 .as_ref()
357 .map(|usage| completion::Usage {
358 input_tokens: usage.prompt_tokens as u64,
359 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
360 total_tokens: usage.total_tokens as u64,
361 })
362 .unwrap_or_default();
363
364 Ok(completion::CompletionResponse {
365 choice,
366 usage,
367 raw_response: response,
368 })
369 }
370}
371
/// One entry of a completion response's `choices` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Index of this choice as returned by the API.
    pub index: usize,
    /// The message carried by this choice.
    pub message: Message,
    /// Log-probability data, kept as raw JSON when present.
    pub logprobs: Option<serde_json::Value>,
    /// Finish reason string as returned by the API.
    pub finish_reason: String,
}
379
/// A chat message in the Galadriel wire format.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Role string; this file produces `"system"`, `"user"` and `"assistant"`.
    pub role: String,
    /// Text content, if any.
    pub content: Option<String>,
    /// Tool calls attached to the message; defaults to empty when the field is
    /// absent (presumably `null_or_vec` also maps JSON `null` to an empty
    /// vector — confirm against `json_utils`).
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<openai::ToolCall>,
}
387
388impl TryFrom<Message> for message::Message {
389 type Error = message::MessageError;
390
391 fn try_from(message: Message) -> Result<Self, Self::Error> {
392 let tool_calls: Vec<message::ToolCall> = message
393 .tool_calls
394 .into_iter()
395 .map(|tool_call| tool_call.into())
396 .collect();
397
398 match message.role.as_str() {
399 "user" => Ok(Self::User {
400 content: OneOrMany::one(
401 message
402 .content
403 .map(|content| message::UserContent::text(&content))
404 .ok_or_else(|| {
405 message::MessageError::ConversionError("Empty user message".to_string())
406 })?,
407 ),
408 }),
409 "assistant" => Ok(Self::Assistant {
410 id: None,
411 content: OneOrMany::many(
412 tool_calls
413 .into_iter()
414 .map(message::AssistantContent::ToolCall)
415 .chain(
416 message
417 .content
418 .map(|content| message::AssistantContent::text(&content))
419 .into_iter(),
420 ),
421 )
422 .map_err(|_| {
423 message::MessageError::ConversionError("Empty assistant message".to_string())
424 })?,
425 }),
426 _ => Err(message::MessageError::ConversionError(format!(
427 "Unknown role: {}",
428 message.role
429 ))),
430 }
431 }
432}
433
434impl TryFrom<message::Message> for Message {
435 type Error = message::MessageError;
436
437 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
438 match message {
439 message::Message::User { content } => Ok(Self {
440 role: "user".to_string(),
441 content: content.iter().find_map(|c| match c {
442 message::UserContent::Text(text) => Some(text.text.clone()),
443 _ => None,
444 }),
445 tool_calls: vec![],
446 }),
447 message::Message::Assistant { content, .. } => {
448 let mut text_content: Option<String> = None;
449 let mut tool_calls = vec![];
450
451 for c in content.iter() {
452 match c {
453 message::AssistantContent::Text(text) => {
454 text_content = Some(
455 text_content
456 .map(|mut existing| {
457 existing.push('\n');
458 existing.push_str(&text.text);
459 existing
460 })
461 .unwrap_or_else(|| text.text.clone()),
462 );
463 }
464 message::AssistantContent::ToolCall(tool_call) => {
465 tool_calls.push(tool_call.clone().into());
466 }
467 message::AssistantContent::Reasoning(_) => {
468 return Err(MessageError::ConversionError(
469 "Galadriel currently doesn't support reasoning.".into(),
470 ));
471 }
472 }
473 }
474
475 Ok(Self {
476 role: "assistant".to_string(),
477 content: text_content,
478 tool_calls,
479 })
480 }
481 }
482 }
483}
484
/// Tool entry as sent in the request's `tools` array.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind; set to `"function"` when converting from a
    /// `completion::ToolDefinition`.
    pub r#type: String,
    /// The wrapped tool definition.
    pub function: completion::ToolDefinition,
}
490
491impl From<completion::ToolDefinition> for ToolDefinition {
492 fn from(tool: completion::ToolDefinition) -> Self {
493 Self {
494 r#type: "function".into(),
495 function: tool,
496 }
497 }
498}
499
/// A function name paired with its arguments, both as strings.
#[derive(Debug, Deserialize)]
pub struct Function {
    /// Function name.
    pub name: String,
    /// Arguments as a string (presumably JSON-encoded — confirm against the API).
    pub arguments: String,
}
505
/// A Galadriel completion model: a client paired with a model name.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests.
    client: Client<T>,
    /// Model name sent as the `model` field of each request.
    pub model: String,
}
512
513impl<T> CompletionModel<T>
514where
515 T: HttpClientExt,
516{
517 pub fn new(client: Client<T>, model: &str) -> Self {
518 Self {
519 client,
520 model: model.to_string(),
521 }
522 }
523
524 pub(crate) fn create_completion_request(
525 &self,
526 completion_request: CompletionRequest,
527 ) -> Result<Value, CompletionError> {
528 let mut partial_history = vec![];
530 if let Some(docs) = completion_request.normalized_documents() {
531 partial_history.push(docs);
532 }
533 partial_history.extend(completion_request.chat_history);
534
535 let mut full_history: Vec<Message> = match &completion_request.preamble {
537 Some(preamble) => vec![Message {
538 role: "system".to_string(),
539 content: Some(preamble.to_string()),
540 tool_calls: vec![],
541 }],
542 None => vec![],
543 };
544
545 full_history.extend(
547 partial_history
548 .into_iter()
549 .map(message::Message::try_into)
550 .collect::<Result<Vec<Message>, _>>()?,
551 );
552
553 let tool_choice = completion_request
554 .tool_choice
555 .clone()
556 .map(crate::providers::openai::completion::ToolChoice::try_from)
557 .transpose()?;
558
559 let request = if completion_request.tools.is_empty() {
560 json!({
561 "model": self.model,
562 "messages": full_history,
563 "temperature": completion_request.temperature,
564 })
565 } else {
566 json!({
567 "model": self.model,
568 "messages": full_history,
569 "temperature": completion_request.temperature,
570 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
571 "tool_choice": tool_choice,
572 })
573 };
574
575 let request = if let Some(params) = completion_request.additional_params {
576 json_utils::merge(request, params)
577 } else {
578 request
579 };
580
581 Ok(request)
582 }
583}
584
585impl completion::CompletionModel for CompletionModel<reqwest::Client> {
586 type Response = CompletionResponse;
587 type StreamingResponse = openai::StreamingCompletionResponse;
588
589 #[cfg_attr(feature = "worker", worker::send)]
590 async fn completion(
591 &self,
592 completion_request: CompletionRequest,
593 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
594 let preamble = completion_request.preamble.clone();
595 let request = self.create_completion_request(completion_request)?;
596 let body = serde_json::to_vec(&request)?;
597
598 let req = self
599 .client
600 .post("/chat/completions")?
601 .header("Content-Type", "application/json")
602 .body(body)
603 .map_err(http_client::Error::from)?;
604
605 let span = if tracing::Span::current().is_disabled() {
606 info_span!(
607 target: "rig::completions",
608 "chat",
609 gen_ai.operation.name = "chat",
610 gen_ai.provider.name = "galadriel",
611 gen_ai.request.model = self.model,
612 gen_ai.system_instructions = preamble,
613 gen_ai.response.id = tracing::field::Empty,
614 gen_ai.response.model = tracing::field::Empty,
615 gen_ai.usage.output_tokens = tracing::field::Empty,
616 gen_ai.usage.input_tokens = tracing::field::Empty,
617 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
618 gen_ai.output.messages = tracing::field::Empty,
619 )
620 } else {
621 tracing::Span::current()
622 };
623
624 async move {
625 let response = self.client.send(req).await?;
626
627 if response.status().is_success() {
628 let t = http_client::text(response).await?;
629 tracing::debug!(target: "rig::completions", "Galadriel completion response: {t}");
630
631 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
632 ApiResponse::Ok(response) => {
633 let span = tracing::Span::current();
634 span.record("gen_ai.response.id", response.id.clone());
635 span.record("gen_ai.response.model_name", response.model.clone());
636 span.record(
637 "gen_ai.output.messages",
638 serde_json::to_string(&response.choices).unwrap(),
639 );
640 if let Some(ref usage) = response.usage {
641 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
642 span.record(
643 "gen_ai.usage.output_tokens",
644 usage.total_tokens - usage.prompt_tokens,
645 );
646 }
647 response.try_into()
648 }
649 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
650 }
651 } else {
652 let text = http_client::text(response).await?;
653
654 Err(CompletionError::ProviderError(text))
655 }
656 }
657 .instrument(span)
658 .await
659 }
660
661 #[cfg_attr(feature = "worker", worker::send)]
662 async fn stream(
663 &self,
664 request: CompletionRequest,
665 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
666 let preamble = request.preamble.clone();
667 let mut request = self.create_completion_request(request)?;
668
669 request = merge(
670 request,
671 json!({"stream": true, "stream_options": {"include_usage": true}}),
672 );
673
674 let builder = self
675 .client
676 .reqwest_post("/chat/completions")
677 .header("Content-Type", "application/json")
678 .json(&request);
679
680 let span = if tracing::Span::current().is_disabled() {
681 info_span!(
682 target: "rig::completions",
683 "chat_streaming",
684 gen_ai.operation.name = "chat_streaming",
685 gen_ai.provider.name = "galadriel",
686 gen_ai.request.model = self.model,
687 gen_ai.system_instructions = preamble,
688 gen_ai.response.id = tracing::field::Empty,
689 gen_ai.response.model = tracing::field::Empty,
690 gen_ai.usage.output_tokens = tracing::field::Empty,
691 gen_ai.usage.input_tokens = tracing::field::Empty,
692 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
693 gen_ai.output.messages = tracing::field::Empty,
694 )
695 } else {
696 tracing::Span::current()
697 };
698
699 send_compatible_streaming_request(builder)
700 .instrument(span)
701 .await
702 }
703}