1use super::openai;
14use crate::client::{
15 ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError,
16};
17use crate::json_utils::merge;
18use crate::message::MessageError;
19use crate::providers::openai::send_compatible_streaming_request;
20use crate::streaming::StreamingCompletionResponse;
21use crate::{
22 OneOrMany,
23 completion::{self, CompletionError, CompletionRequest},
24 impl_conversion_traits, json_utils, message,
25};
26use serde::{Deserialize, Serialize};
27use serde_json::{Value, json};
28
/// Default base URL for Galadriel's verified API; request paths are appended to it.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
33
34pub struct ClientBuilder<'a> {
35 api_key: &'a str,
36 fine_tune_api_key: Option<&'a str>,
37 base_url: &'a str,
38 http_client: Option<reqwest::Client>,
39}
40
41impl<'a> ClientBuilder<'a> {
42 pub fn new(api_key: &'a str) -> Self {
43 Self {
44 api_key,
45 fine_tune_api_key: None,
46 base_url: GALADRIEL_API_BASE_URL,
47 http_client: None,
48 }
49 }
50
51 pub fn fine_tune_api_key(mut self, fine_tune_api_key: &'a str) -> Self {
52 self.fine_tune_api_key = Some(fine_tune_api_key);
53 self
54 }
55
56 pub fn base_url(mut self, base_url: &'a str) -> Self {
57 self.base_url = base_url;
58 self
59 }
60
61 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
62 self.http_client = Some(client);
63 self
64 }
65
66 pub fn build(self) -> Result<Client, ClientBuilderError> {
67 let http_client = if let Some(http_client) = self.http_client {
68 http_client
69 } else {
70 reqwest::Client::builder().build()?
71 };
72
73 Ok(Client {
74 base_url: self.base_url.to_string(),
75 api_key: self.api_key.to_string(),
76 fine_tune_api_key: self.fine_tune_api_key.map(|x| x.to_string()),
77 http_client,
78 })
79 }
80}
81#[derive(Clone)]
82pub struct Client {
83 base_url: String,
84 api_key: String,
85 fine_tune_api_key: Option<String>,
86 http_client: reqwest::Client,
87}
88
89impl std::fmt::Debug for Client {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 f.debug_struct("Client")
92 .field("base_url", &self.base_url)
93 .field("http_client", &self.http_client)
94 .field("api_key", &"<REDACTED>")
95 .field("fine_tune_api_key", &"<REDACTED>")
96 .finish()
97 }
98}
99
100impl Client {
101 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
112 ClientBuilder::new(api_key)
113 }
114
115 pub fn new(api_key: &str) -> Self {
120 Self::builder(api_key)
121 .build()
122 .expect("Galadriel client should build")
123 }
124
125 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
126 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
127 let mut client = self.http_client.post(url).bearer_auth(&self.api_key);
128
129 if let Some(fine_tune_key) = self.fine_tune_api_key.clone() {
130 client = client.header("Fine-Tune-Authorization", fine_tune_key);
131 }
132
133 client
134 }
135}
136
137impl ProviderClient for Client {
138 fn from_env() -> Self {
142 let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
143 let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
144 let mut builder = Self::builder(&api_key);
145 if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
146 builder = builder.fine_tune_api_key(fine_tune_api_key);
147 }
148 builder.build().expect("Galadriel client should build")
149 }
150
151 fn from_val(input: crate::client::ProviderValue) -> Self {
152 let crate::client::ProviderValue::ApiKeyWithOptionalKey(api_key, fine_tune_key) = input
153 else {
154 panic!("Incorrect provider value type")
155 };
156 let mut builder = Self::builder(&api_key);
157 if let Some(fine_tune_key) = fine_tune_key.as_deref() {
158 builder = builder.fine_tune_api_key(fine_tune_key);
159 }
160 builder.build().expect("Galadriel client should build")
161 }
162}
163
164impl CompletionClient for Client {
165 type CompletionModel = CompletionModel;
166
167 fn completion_model(&self, model: &str) -> CompletionModel {
179 CompletionModel::new(self.clone(), model)
180 }
181}
182
183impl VerifyClient for Client {
184 #[cfg_attr(feature = "worker", worker::send)]
185 async fn verify(&self) -> Result<(), VerifyError> {
186 Ok(())
188 }
189}
190
191impl_conversion_traits!(
192 AsEmbeddings,
193 AsTranscription,
194 AsImageGeneration,
195 AsAudioGeneration for Client
196);
197
198#[derive(Debug, Deserialize)]
199struct ApiErrorResponse {
200 message: String,
201}
202
203#[derive(Debug, Deserialize)]
204#[serde(untagged)]
205enum ApiResponse<T> {
206 Ok(T),
207 Err(ApiErrorResponse),
208}
209
210#[derive(Clone, Debug, Deserialize, Serialize)]
211pub struct Usage {
212 pub prompt_tokens: usize,
213 pub total_tokens: usize,
214}
215
216impl std::fmt::Display for Usage {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 write!(
219 f,
220 "Prompt tokens: {} Total tokens: {}",
221 self.prompt_tokens, self.total_tokens
222 )
223 }
224}
225
// Model identifiers accepted in the request's `model` field.

/// `o1-preview`
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12`
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini`
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12`
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o`
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13`
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo`
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09`
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview`
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview`
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview`
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview`
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview`
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4`
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613`
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k`
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613`
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo`
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125`
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106`
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct`
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
271
272#[derive(Debug, Deserialize, Serialize)]
273pub struct CompletionResponse {
274 pub id: String,
275 pub object: String,
276 pub created: u64,
277 pub model: String,
278 pub system_fingerprint: Option<String>,
279 pub choices: Vec<Choice>,
280 pub usage: Option<Usage>,
281}
282
283impl From<ApiErrorResponse> for CompletionError {
284 fn from(err: ApiErrorResponse) -> Self {
285 CompletionError::ProviderError(err.message)
286 }
287}
288
289impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
290 type Error = CompletionError;
291
292 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
293 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
294 CompletionError::ResponseError("Response contained no choices".to_owned())
295 })?;
296
297 let mut content = message
298 .content
299 .as_ref()
300 .map(|c| vec![completion::AssistantContent::text(c)])
301 .unwrap_or_default();
302
303 content.extend(message.tool_calls.iter().map(|call| {
304 completion::AssistantContent::tool_call(
305 &call.function.name,
306 &call.function.name,
307 call.function.arguments.clone(),
308 )
309 }));
310
311 let choice = OneOrMany::many(content).map_err(|_| {
312 CompletionError::ResponseError(
313 "Response contained no message or tool call (empty)".to_owned(),
314 )
315 })?;
316 let usage = response
317 .usage
318 .as_ref()
319 .map(|usage| completion::Usage {
320 input_tokens: usage.prompt_tokens as u64,
321 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
322 total_tokens: usage.total_tokens as u64,
323 })
324 .unwrap_or_default();
325
326 Ok(completion::CompletionResponse {
327 choice,
328 usage,
329 raw_response: response,
330 })
331 }
332}
333
334#[derive(Debug, Deserialize, Serialize)]
335pub struct Choice {
336 pub index: usize,
337 pub message: Message,
338 pub logprobs: Option<serde_json::Value>,
339 pub finish_reason: String,
340}
341
342#[derive(Debug, Serialize, Deserialize)]
343pub struct Message {
344 pub role: String,
345 pub content: Option<String>,
346 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
347 pub tool_calls: Vec<openai::ToolCall>,
348}
349
350impl TryFrom<Message> for message::Message {
351 type Error = message::MessageError;
352
353 fn try_from(message: Message) -> Result<Self, Self::Error> {
354 let tool_calls: Vec<message::ToolCall> = message
355 .tool_calls
356 .into_iter()
357 .map(|tool_call| tool_call.into())
358 .collect();
359
360 match message.role.as_str() {
361 "user" => Ok(Self::User {
362 content: OneOrMany::one(
363 message
364 .content
365 .map(|content| message::UserContent::text(&content))
366 .ok_or_else(|| {
367 message::MessageError::ConversionError("Empty user message".to_string())
368 })?,
369 ),
370 }),
371 "assistant" => Ok(Self::Assistant {
372 id: None,
373 content: OneOrMany::many(
374 tool_calls
375 .into_iter()
376 .map(message::AssistantContent::ToolCall)
377 .chain(
378 message
379 .content
380 .map(|content| message::AssistantContent::text(&content))
381 .into_iter(),
382 ),
383 )
384 .map_err(|_| {
385 message::MessageError::ConversionError("Empty assistant message".to_string())
386 })?,
387 }),
388 _ => Err(message::MessageError::ConversionError(format!(
389 "Unknown role: {}",
390 message.role
391 ))),
392 }
393 }
394}
395
396impl TryFrom<message::Message> for Message {
397 type Error = message::MessageError;
398
399 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
400 match message {
401 message::Message::User { content } => Ok(Self {
402 role: "user".to_string(),
403 content: content.iter().find_map(|c| match c {
404 message::UserContent::Text(text) => Some(text.text.clone()),
405 _ => None,
406 }),
407 tool_calls: vec![],
408 }),
409 message::Message::Assistant { content, .. } => {
410 let mut text_content: Option<String> = None;
411 let mut tool_calls = vec![];
412
413 for c in content.iter() {
414 match c {
415 message::AssistantContent::Text(text) => {
416 text_content = Some(
417 text_content
418 .map(|mut existing| {
419 existing.push('\n');
420 existing.push_str(&text.text);
421 existing
422 })
423 .unwrap_or_else(|| text.text.clone()),
424 );
425 }
426 message::AssistantContent::ToolCall(tool_call) => {
427 tool_calls.push(tool_call.clone().into());
428 }
429 message::AssistantContent::Reasoning(_) => {
430 return Err(MessageError::ConversionError(
431 "Galadriel currently doesn't support reasoning.".into(),
432 ));
433 }
434 }
435 }
436
437 Ok(Self {
438 role: "assistant".to_string(),
439 content: text_content,
440 tool_calls,
441 })
442 }
443 }
444 }
445}
446
447#[derive(Clone, Debug, Deserialize, Serialize)]
448pub struct ToolDefinition {
449 pub r#type: String,
450 pub function: completion::ToolDefinition,
451}
452
453impl From<completion::ToolDefinition> for ToolDefinition {
454 fn from(tool: completion::ToolDefinition) -> Self {
455 Self {
456 r#type: "function".into(),
457 function: tool,
458 }
459 }
460}
461
462#[derive(Debug, Deserialize)]
463pub struct Function {
464 pub name: String,
465 pub arguments: String,
466}
467
468#[derive(Clone)]
469pub struct CompletionModel {
470 client: Client,
471 pub model: String,
473}
474
475impl CompletionModel {
476 pub(crate) fn create_completion_request(
477 &self,
478 completion_request: CompletionRequest,
479 ) -> Result<Value, CompletionError> {
480 let mut partial_history = vec![];
482 if let Some(docs) = completion_request.normalized_documents() {
483 partial_history.push(docs);
484 }
485 partial_history.extend(completion_request.chat_history);
486
487 let mut full_history: Vec<Message> = match &completion_request.preamble {
489 Some(preamble) => vec![Message {
490 role: "system".to_string(),
491 content: Some(preamble.to_string()),
492 tool_calls: vec![],
493 }],
494 None => vec![],
495 };
496
497 full_history.extend(
499 partial_history
500 .into_iter()
501 .map(message::Message::try_into)
502 .collect::<Result<Vec<Message>, _>>()?,
503 );
504
505 let request = if completion_request.tools.is_empty() {
506 json!({
507 "model": self.model,
508 "messages": full_history,
509 "temperature": completion_request.temperature,
510 })
511 } else {
512 json!({
513 "model": self.model,
514 "messages": full_history,
515 "temperature": completion_request.temperature,
516 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
517 "tool_choice": "auto",
518 })
519 };
520
521 let request = if let Some(params) = completion_request.additional_params {
522 json_utils::merge(request, params)
523 } else {
524 request
525 };
526
527 Ok(request)
528 }
529}
530
531impl CompletionModel {
532 pub fn new(client: Client, model: &str) -> Self {
533 Self {
534 client,
535 model: model.to_string(),
536 }
537 }
538}
539
540impl completion::CompletionModel for CompletionModel {
541 type Response = CompletionResponse;
542 type StreamingResponse = openai::StreamingCompletionResponse;
543
544 #[cfg_attr(feature = "worker", worker::send)]
545 async fn completion(
546 &self,
547 completion_request: CompletionRequest,
548 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
549 let request = self.create_completion_request(completion_request)?;
550
551 let response = self
552 .client
553 .post("/chat/completions")
554 .json(&request)
555 .send()
556 .await?;
557
558 if response.status().is_success() {
559 let t = response.text().await?;
560 tracing::debug!(target: "rig", "Galadriel completion error: {}", t);
561
562 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
563 ApiResponse::Ok(response) => {
564 tracing::info!(target: "rig",
565 "Galadriel completion token usage: {:?}",
566 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
567 );
568 response.try_into()
569 }
570 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
571 }
572 } else {
573 Err(CompletionError::ProviderError(response.text().await?))
574 }
575 }
576
577 #[cfg_attr(feature = "worker", worker::send)]
578 async fn stream(
579 &self,
580 request: CompletionRequest,
581 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
582 let mut request = self.create_completion_request(request)?;
583
584 request = merge(
585 request,
586 json!({"stream": true, "stream_options": {"include_usage": true}}),
587 );
588
589 let builder = self.client.post("/chat/completions").json(&request);
590
591 send_compatible_streaming_request(builder).await
592 }
593}