1use super::openai;
14use crate::client::{ClientBuilderError, CompletionClient, ProviderClient};
15use crate::json_utils::merge;
16use crate::message::MessageError;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 impl_conversion_traits, json_utils, message,
23};
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26
/// Default base URL used by [`ClientBuilder::new`]; override it with [`ClientBuilder::base_url`].
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
31
/// Builder for [`Client`], created via [`Client::builder`] or [`ClientBuilder::new`].
///
/// String inputs are borrowed for `'a` and copied into owned `String`s by `build`.
pub struct ClientBuilder<'a> {
    /// API key attached to every request as a bearer token.
    api_key: &'a str,
    /// Optional key sent in the `Fine-Tune-Authorization` header.
    fine_tune_api_key: Option<&'a str>,
    /// Base URL that request paths are appended to; defaults to `GALADRIEL_API_BASE_URL`.
    base_url: &'a str,
    /// Custom HTTP client; when `None`, `build` creates a default `reqwest::Client`.
    http_client: Option<reqwest::Client>,
}
38
39impl<'a> ClientBuilder<'a> {
40 pub fn new(api_key: &'a str) -> Self {
41 Self {
42 api_key,
43 fine_tune_api_key: None,
44 base_url: GALADRIEL_API_BASE_URL,
45 http_client: None,
46 }
47 }
48
49 pub fn fine_tune_api_key(mut self, fine_tune_api_key: &'a str) -> Self {
50 self.fine_tune_api_key = Some(fine_tune_api_key);
51 self
52 }
53
54 pub fn base_url(mut self, base_url: &'a str) -> Self {
55 self.base_url = base_url;
56 self
57 }
58
59 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
60 self.http_client = Some(client);
61 self
62 }
63
64 pub fn build(self) -> Result<Client, ClientBuilderError> {
65 let http_client = if let Some(http_client) = self.http_client {
66 http_client
67 } else {
68 reqwest::Client::builder().build()?
69 };
70
71 Ok(Client {
72 base_url: self.base_url.to_string(),
73 api_key: self.api_key.to_string(),
74 fine_tune_api_key: self.fine_tune_api_key.map(|x| x.to_string()),
75 http_client,
76 })
77 }
78}
/// Galadriel API client: base URL, credentials, and a reusable HTTP client.
///
/// `Debug` is implemented by hand so both API keys are printed as `<REDACTED>`.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// API key attached to every request as a bearer token.
    api_key: String,
    /// Optional key sent in the `Fine-Tune-Authorization` header when present.
    fine_tune_api_key: Option<String>,
    /// Underlying HTTP client used to issue requests.
    http_client: reqwest::Client,
}
86
87impl std::fmt::Debug for Client {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("Client")
90 .field("base_url", &self.base_url)
91 .field("http_client", &self.http_client)
92 .field("api_key", &"<REDACTED>")
93 .field("fine_tune_api_key", &"<REDACTED>")
94 .finish()
95 }
96}
97
98impl Client {
99 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
110 ClientBuilder::new(api_key)
111 }
112
113 pub fn new(api_key: &str) -> Self {
118 Self::builder(api_key)
119 .build()
120 .expect("Galadriel client should build")
121 }
122
123 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
124 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
125 let mut client = self.http_client.post(url).bearer_auth(&self.api_key);
126
127 if let Some(fine_tune_key) = self.fine_tune_api_key.clone() {
128 client = client.header("Fine-Tune-Authorization", fine_tune_key);
129 }
130
131 client
132 }
133}
134
135impl ProviderClient for Client {
136 fn from_env() -> Self {
140 let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
141 let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
142 let mut builder = Self::builder(&api_key);
143 if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
144 builder = builder.fine_tune_api_key(fine_tune_api_key);
145 }
146 builder.build().expect("Galadriel client should build")
147 }
148
149 fn from_val(input: crate::client::ProviderValue) -> Self {
150 let crate::client::ProviderValue::ApiKeyWithOptionalKey(api_key, fine_tune_key) = input
151 else {
152 panic!("Incorrect provider value type")
153 };
154 let mut builder = Self::builder(&api_key);
155 if let Some(fine_tune_key) = fine_tune_key.as_deref() {
156 builder = builder.fine_tune_api_key(fine_tune_key);
157 }
158 builder.build().expect("Galadriel client should build")
159 }
160}
161
162impl CompletionClient for Client {
163 type CompletionModel = CompletionModel;
164
165 fn completion_model(&self, model: &str) -> CompletionModel {
177 CompletionModel::new(self.clone(), model)
178 }
179}
180
// Generates the `AsEmbeddings`, `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` conversion traits for `Client`. Only completions are implemented in
// this file, so these are presumably the macro's "not supported" defaults.
// NOTE(review): confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
187
/// Error body returned by the provider; only its `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Human-readable error text, surfaced as `CompletionError::ProviderError`.
    message: String,
}
192
/// Either a successful payload `T` or a provider error object.
///
/// With `#[serde(untagged)]` the variants are tried in declaration order. The body is parsed
/// as `T` first, and only if that fails as an [`ApiErrorResponse`]. If neither shape matches,
/// deserialization fails with serde's generic untagged-enum mismatch error. Keep `Ok` first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
199
/// Token usage reported with a completion.
///
/// Only prompt and total counts are present. Output tokens are derived as
/// `total_tokens - prompt_tokens` when converting to rig's `completion::Usage`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Tokens consumed by the input messages.
    pub prompt_tokens: usize,
    /// Total tokens for the request (prompt plus generated output).
    pub total_tokens: usize,
}
205
206impl std::fmt::Display for Usage {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 write!(
209 f,
210 "Prompt tokens: {} Total tokens: {}",
211 self.prompt_tokens, self.total_tokens
212 )
213 }
214}
215
// Model identifiers that can be passed to `Client::completion_model`. They are OpenAI model
// names; availability on the Galadriel endpoint is not verified here.

/// The `o1-preview` model.
pub const O1_PREVIEW: &str = "o1-preview";
/// The `o1-preview-2024-09-12` snapshot.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// The `o1-mini` model.
pub const O1_MINI: &str = "o1-mini";
/// The `o1-mini-2024-09-12` snapshot.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// The `gpt-4o` model.
pub const GPT_4O: &str = "gpt-4o";
/// The `gpt-4o-2024-05-13` snapshot.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// The `gpt-4-turbo` model.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// The `gpt-4-turbo-2024-04-09` snapshot.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// The `gpt-4-turbo-preview` model.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// The `gpt-4-0125-preview` snapshot.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// The `gpt-4-1106-preview` snapshot.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// The `gpt-4-vision-preview` model.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// The `gpt-4-1106-vision-preview` snapshot.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// The `gpt-4` model.
pub const GPT_4: &str = "gpt-4";
/// The `gpt-4-0613` snapshot.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// The `gpt-4-32k` model.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// The `gpt-4-32k-0613` snapshot.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// The `gpt-3.5-turbo` model.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// The `gpt-3.5-turbo-0125` snapshot.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// The `gpt-3.5-turbo-1106` snapshot.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// The `gpt-3.5-turbo-instruct` model.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
261
/// Raw body of a non-streaming chat completion response.
///
/// After conversion it is kept verbatim as `raw_response` in rig's
/// `completion::CompletionResponse`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    /// Creation timestamp as reported by the API (presumably Unix seconds; TODO confirm).
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    /// Completion alternatives; only the first one is converted.
    pub choices: Vec<Choice>,
    /// Token usage, when the API reports it.
    pub usage: Option<Usage>,
}
272
273impl From<ApiErrorResponse> for CompletionError {
274 fn from(err: ApiErrorResponse) -> Self {
275 CompletionError::ProviderError(err.message)
276 }
277}
278
279impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
280 type Error = CompletionError;
281
282 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
283 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
284 CompletionError::ResponseError("Response contained no choices".to_owned())
285 })?;
286
287 let mut content = message
288 .content
289 .as_ref()
290 .map(|c| vec![completion::AssistantContent::text(c)])
291 .unwrap_or_default();
292
293 content.extend(message.tool_calls.iter().map(|call| {
294 completion::AssistantContent::tool_call(
295 &call.function.name,
296 &call.function.name,
297 call.function.arguments.clone(),
298 )
299 }));
300
301 let choice = OneOrMany::many(content).map_err(|_| {
302 CompletionError::ResponseError(
303 "Response contained no message or tool call (empty)".to_owned(),
304 )
305 })?;
306 let usage = response
307 .usage
308 .as_ref()
309 .map(|usage| completion::Usage {
310 input_tokens: usage.prompt_tokens as u64,
311 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
312 total_tokens: usage.total_tokens as u64,
313 })
314 .unwrap_or_default();
315
316 Ok(completion::CompletionResponse {
317 choice,
318 usage,
319 raw_response: response,
320 })
321 }
322}
323
/// One completion alternative; conversion to rig's response uses only the first choice.
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    /// Log-probability data, kept as untyped JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Stop reason string as reported by the API.
    pub finish_reason: String,
}
331
/// A chat message in Galadriel's request/response wire format.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Sender role; this file produces `"system"`, `"user"` and `"assistant"`, and converts
    /// back only `"user"` and `"assistant"`.
    pub role: String,
    /// Text body; `None` when the message has no text (e.g. tool calls only).
    pub content: Option<String>,
    /// Tool calls requested by the assistant. A missing field defaults to empty; a JSON
    /// `null` is presumably also mapped to empty by `json_utils::null_or_vec`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<openai::ToolCall>,
}
339
340impl TryFrom<Message> for message::Message {
341 type Error = message::MessageError;
342
343 fn try_from(message: Message) -> Result<Self, Self::Error> {
344 let tool_calls: Vec<message::ToolCall> = message
345 .tool_calls
346 .into_iter()
347 .map(|tool_call| tool_call.into())
348 .collect();
349
350 match message.role.as_str() {
351 "user" => Ok(Self::User {
352 content: OneOrMany::one(
353 message
354 .content
355 .map(|content| message::UserContent::text(&content))
356 .ok_or_else(|| {
357 message::MessageError::ConversionError("Empty user message".to_string())
358 })?,
359 ),
360 }),
361 "assistant" => Ok(Self::Assistant {
362 id: None,
363 content: OneOrMany::many(
364 tool_calls
365 .into_iter()
366 .map(message::AssistantContent::ToolCall)
367 .chain(
368 message
369 .content
370 .map(|content| message::AssistantContent::text(&content))
371 .into_iter(),
372 ),
373 )
374 .map_err(|_| {
375 message::MessageError::ConversionError("Empty assistant message".to_string())
376 })?,
377 }),
378 _ => Err(message::MessageError::ConversionError(format!(
379 "Unknown role: {}",
380 message.role
381 ))),
382 }
383 }
384}
385
386impl TryFrom<message::Message> for Message {
387 type Error = message::MessageError;
388
389 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
390 match message {
391 message::Message::User { content } => Ok(Self {
392 role: "user".to_string(),
393 content: content.iter().find_map(|c| match c {
394 message::UserContent::Text(text) => Some(text.text.clone()),
395 _ => None,
396 }),
397 tool_calls: vec![],
398 }),
399 message::Message::Assistant { content, .. } => {
400 let mut text_content: Option<String> = None;
401 let mut tool_calls = vec![];
402
403 for c in content.iter() {
404 match c {
405 message::AssistantContent::Text(text) => {
406 text_content = Some(
407 text_content
408 .map(|mut existing| {
409 existing.push('\n');
410 existing.push_str(&text.text);
411 existing
412 })
413 .unwrap_or_else(|| text.text.clone()),
414 );
415 }
416 message::AssistantContent::ToolCall(tool_call) => {
417 tool_calls.push(tool_call.clone().into());
418 }
419 message::AssistantContent::Reasoning(_) => {
420 return Err(MessageError::ConversionError(
421 "Galadriel currently doesn't support reasoning.".into(),
422 ));
423 }
424 }
425 }
426
427 Ok(Self {
428 role: "assistant".to_string(),
429 content: text_content,
430 tool_calls,
431 })
432 }
433 }
434 }
435}
436
/// Tool entry sent in the request's `tools` array.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind; set to `"function"` by the `From<completion::ToolDefinition>` impl.
    pub r#type: String,
    /// The wrapped rig tool definition.
    pub function: completion::ToolDefinition,
}
442
443impl From<completion::ToolDefinition> for ToolDefinition {
444 fn from(tool: completion::ToolDefinition) -> Self {
445 Self {
446 r#type: "function".into(),
447 function: tool,
448 }
449 }
450}
451
/// A function call as a name plus raw string arguments.
///
/// NOTE(review): not referenced elsewhere in this file (tool calls use `openai::ToolCall`);
/// confirm whether it is still needed.
#[derive(Debug, Deserialize)]
pub struct Function {
    pub name: String,
    /// Arguments as an unparsed string (presumably JSON text; TODO confirm).
    pub arguments: String,
}
457
458#[derive(Clone)]
459pub struct CompletionModel {
460 client: Client,
461 pub model: String,
463}
464
465impl CompletionModel {
466 pub(crate) fn create_completion_request(
467 &self,
468 completion_request: CompletionRequest,
469 ) -> Result<Value, CompletionError> {
470 let mut partial_history = vec![];
472 if let Some(docs) = completion_request.normalized_documents() {
473 partial_history.push(docs);
474 }
475 partial_history.extend(completion_request.chat_history);
476
477 let mut full_history: Vec<Message> = match &completion_request.preamble {
479 Some(preamble) => vec![Message {
480 role: "system".to_string(),
481 content: Some(preamble.to_string()),
482 tool_calls: vec![],
483 }],
484 None => vec![],
485 };
486
487 full_history.extend(
489 partial_history
490 .into_iter()
491 .map(message::Message::try_into)
492 .collect::<Result<Vec<Message>, _>>()?,
493 );
494
495 let request = if completion_request.tools.is_empty() {
496 json!({
497 "model": self.model,
498 "messages": full_history,
499 "temperature": completion_request.temperature,
500 })
501 } else {
502 json!({
503 "model": self.model,
504 "messages": full_history,
505 "temperature": completion_request.temperature,
506 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
507 "tool_choice": "auto",
508 })
509 };
510
511 let request = if let Some(params) = completion_request.additional_params {
512 json_utils::merge(request, params)
513 } else {
514 request
515 };
516
517 Ok(request)
518 }
519}
520
521impl CompletionModel {
522 pub fn new(client: Client, model: &str) -> Self {
523 Self {
524 client,
525 model: model.to_string(),
526 }
527 }
528}
529
530impl completion::CompletionModel for CompletionModel {
531 type Response = CompletionResponse;
532 type StreamingResponse = openai::StreamingCompletionResponse;
533
534 #[cfg_attr(feature = "worker", worker::send)]
535 async fn completion(
536 &self,
537 completion_request: CompletionRequest,
538 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
539 let request = self.create_completion_request(completion_request)?;
540
541 let response = self
542 .client
543 .post("/chat/completions")
544 .json(&request)
545 .send()
546 .await?;
547
548 if response.status().is_success() {
549 let t = response.text().await?;
550 tracing::debug!(target: "rig", "Galadriel completion error: {}", t);
551
552 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
553 ApiResponse::Ok(response) => {
554 tracing::info!(target: "rig",
555 "Galadriel completion token usage: {:?}",
556 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
557 );
558 response.try_into()
559 }
560 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
561 }
562 } else {
563 Err(CompletionError::ProviderError(response.text().await?))
564 }
565 }
566
567 #[cfg_attr(feature = "worker", worker::send)]
568 async fn stream(
569 &self,
570 request: CompletionRequest,
571 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
572 let mut request = self.create_completion_request(request)?;
573
574 request = merge(
575 request,
576 json!({"stream": true, "stream_options": {"include_usage": true}}),
577 );
578
579 let builder = self.client.post("/chat/completions").json(&request);
580
581 send_compatible_streaming_request(builder).await
582 }
583}