1use super::openai;
14use crate::client::{CompletionClient, ProviderClient};
15use crate::json_utils::merge;
16use crate::providers::openai::send_compatible_streaming_request;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19 OneOrMany,
20 completion::{self, CompletionError, CompletionRequest},
21 impl_conversion_traits, json_utils, message,
22};
23use serde::{Deserialize, Serialize};
24use serde_json::{Value, json};
25
/// Base URL that `Client::new` passes along when the caller does not supply one.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
30
/// HTTP client for the Galadriel API.
///
/// Holds the endpoint prefix, the credentials, and the underlying `reqwest`
/// client used to issue requests.
#[derive(Clone)]
pub struct Client {
    /// Prefix that request paths are appended to when building a URL.
    base_url: String,
    /// Primary API key; attached to every POST as a bearer token.
    api_key: String,
    /// Optional secondary key; when present it is sent in the
    /// `Fine-Tune-Authorization` header.
    fine_tune_api_key: Option<String>,
    /// Underlying HTTP client used to build requests.
    http_client: reqwest::Client,
}
38
39impl std::fmt::Debug for Client {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.debug_struct("Client")
42 .field("base_url", &self.base_url)
43 .field("http_client", &self.http_client)
44 .field("api_key", &"<REDACTED>")
45 .field("fine_tune_api_key", &"<REDACTED>")
46 .finish()
47 }
48}
49
50impl Client {
51 pub fn new(api_key: &str, fine_tune_api_key: Option<&str>) -> Self {
53 Self::from_url_with_optional_key(api_key, GALADRIEL_API_BASE_URL, fine_tune_api_key)
54 }
55
56 pub fn from_url(api_key: &str, base_url: &str, fine_tune_api_key: Option<&str>) -> Self {
58 Self::from_url_with_optional_key(api_key, base_url, fine_tune_api_key)
59 }
60
61 pub fn from_url_with_optional_key(
62 api_key: &str,
63 base_url: &str,
64 fine_tune_api_key: Option<&str>,
65 ) -> Self {
66 Self {
67 base_url: base_url.to_string(),
68 api_key: api_key.to_string(),
69 fine_tune_api_key: fine_tune_api_key.map(|x| x.to_string()),
70 http_client: reqwest::Client::builder()
71 .build()
72 .expect("Galadriel reqwest client should build"),
73 }
74 }
75
76 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
79 self.http_client = client;
80
81 self
82 }
83
84 fn post(&self, path: &str) -> reqwest::RequestBuilder {
85 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
86 let mut client = self.http_client.post(url).bearer_auth(&self.api_key);
87
88 if let Some(fine_tune_key) = self.fine_tune_api_key.clone() {
89 client = client.header("Fine-Tune-Authorization", fine_tune_key);
90 }
91
92 client
93 }
94}
95
96impl ProviderClient for Client {
97 fn from_env() -> Self {
101 let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
102 let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
103 Self::new(&api_key, fine_tune_api_key.as_deref())
104 }
105
106 fn from_val(input: crate::client::ProviderValue) -> Self {
107 let crate::client::ProviderValue::ApiKeyWithOptionalKey(api_key, fine_tune_key) = input
108 else {
109 panic!("Incorrect provider value type")
110 };
111 Self::new(&api_key, fine_tune_key.as_deref())
112 }
113}
114
115impl CompletionClient for Client {
116 type CompletionModel = CompletionModel;
117
118 fn completion_model(&self, model: &str) -> CompletionModel {
130 CompletionModel::new(self.clone(), model)
131 }
132}
133
// Applies the crate's `impl_conversion_traits!` macro to `Client` for the
// listed capability traits.
// NOTE(review): presumably this marks embeddings, transcription, image and
// audio generation as unsupported by this provider — confirm against the
// macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
140
/// Error payload deserialized from a response body.
///
/// Its `message` is what gets surfaced to callers as
/// `CompletionError::ProviderError`.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Human-readable error text taken from the response body.
    message: String,
}
145
/// Either a successfully parsed payload `T` or an error payload.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`. Keep that order when editing.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body matched the expected payload type.
    Ok(T),
    /// The body matched the error shape instead.
    Err(ApiErrorResponse),
}
152
/// Token accounting attached to a completion response.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Token count reported for the prompt.
    pub prompt_tokens: usize,
    /// Total token count reported for the request.
    /// NOTE(review): presumably prompt plus completion tokens — confirm with the API docs.
    pub total_tokens: usize,
}
158
159impl std::fmt::Display for Usage {
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 write!(
162 f,
163 "Prompt tokens: {} Total tokens: {}",
164 self.prompt_tokens, self.total_tokens
165 )
166 }
167}
168
// Model name strings exposed for convenience.
// NOTE(review): presumably meant to be passed as the `model` argument of
// `Client::completion_model`; whether Galadriel serves every one of these is
// not visible from this file — confirm against the provider's model list.

/// Model name `o1-preview`.
pub const O1_PREVIEW: &str = "o1-preview";
/// Model name `o1-preview-2024-09-12`.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// Model name `o1-mini`.
pub const O1_MINI: &str = "o1-mini";
/// Model name `o1-mini-2024-09-12`.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// Model name `gpt-4o`.
pub const GPT_4O: &str = "gpt-4o";
/// Model name `gpt-4o-2024-05-13`.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// Model name `gpt-4-turbo`.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// Model name `gpt-4-turbo-2024-04-09`.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// Model name `gpt-4-turbo-preview`.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// Model name `gpt-4-0125-preview`.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// Model name `gpt-4-1106-preview`.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// Model name `gpt-4-vision-preview`.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// Model name `gpt-4-1106-vision-preview`.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// Model name `gpt-4`.
pub const GPT_4: &str = "gpt-4";
/// Model name `gpt-4-0613`.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// Model name `gpt-4-32k`.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// Model name `gpt-4-32k-0613`.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// Model name `gpt-3.5-turbo`.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// Model name `gpt-3.5-turbo-0125`.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// Model name `gpt-3.5-turbo-1106`.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// Model name `gpt-3.5-turbo-instruct`.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
214
/// Body of a successful chat-completions call, as deserialized from JSON.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Identifier of this completion as returned in the response body.
    pub id: String,
    /// Value of the `object` field in the response body.
    pub object: String,
    /// Value of the `created` field.
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm.
    pub created: u64,
    /// Model name echoed back in the response body.
    pub model: String,
    /// Optional `system_fingerprint` field from the response body.
    pub system_fingerprint: Option<String>,
    /// Returned choices; only the first one is used when converting to the
    /// crate-level completion response.
    pub choices: Vec<Choice>,
    /// Token usage, when the response includes it.
    pub usage: Option<Usage>,
}
225
226impl From<ApiErrorResponse> for CompletionError {
227 fn from(err: ApiErrorResponse) -> Self {
228 CompletionError::ProviderError(err.message)
229 }
230}
231
232impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
233 type Error = CompletionError;
234
235 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
236 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
237 CompletionError::ResponseError("Response contained no choices".to_owned())
238 })?;
239
240 let mut content = message
241 .content
242 .as_ref()
243 .map(|c| vec![completion::AssistantContent::text(c)])
244 .unwrap_or_default();
245
246 content.extend(message.tool_calls.iter().map(|call| {
247 completion::AssistantContent::tool_call(
248 &call.function.name,
249 &call.function.name,
250 call.function.arguments.clone(),
251 )
252 }));
253
254 let choice = OneOrMany::many(content).map_err(|_| {
255 CompletionError::ResponseError(
256 "Response contained no message or tool call (empty)".to_owned(),
257 )
258 })?;
259 let usage = response
260 .usage
261 .as_ref()
262 .map(|usage| completion::Usage {
263 input_tokens: usage.prompt_tokens as u64,
264 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
265 total_tokens: usage.total_tokens as u64,
266 })
267 .unwrap_or_default();
268
269 Ok(completion::CompletionResponse {
270 choice,
271 usage,
272 raw_response: response,
273 })
274 }
275}
276
/// One entry of the `choices` array in a completion response.
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Value of the `index` field for this choice.
    pub index: usize,
    /// The message carried by this choice.
    pub message: Message,
    /// Raw `logprobs` payload, kept as untyped JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Value of the `finish_reason` field.
    pub finish_reason: String,
}
284
/// Wire-format chat message, used both in request bodies and in responses.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Role string; this file produces `"system"`, `"user"` and `"assistant"`.
    pub role: String,
    /// Optional text content of the message.
    pub content: Option<String>,
    /// Tool calls attached to the message; defaults to empty when the field is absent.
    /// NOTE(review): `null_or_vec` presumably also maps an explicit `null` to
    /// an empty vector — confirm in `json_utils`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<openai::ToolCall>,
}
292
293impl TryFrom<Message> for message::Message {
294 type Error = message::MessageError;
295
296 fn try_from(message: Message) -> Result<Self, Self::Error> {
297 let tool_calls: Vec<message::ToolCall> = message
298 .tool_calls
299 .into_iter()
300 .map(|tool_call| tool_call.into())
301 .collect();
302
303 match message.role.as_str() {
304 "user" => Ok(Self::User {
305 content: OneOrMany::one(
306 message
307 .content
308 .map(|content| message::UserContent::text(&content))
309 .ok_or_else(|| {
310 message::MessageError::ConversionError("Empty user message".to_string())
311 })?,
312 ),
313 }),
314 "assistant" => Ok(Self::Assistant {
315 id: None,
316 content: OneOrMany::many(
317 tool_calls
318 .into_iter()
319 .map(message::AssistantContent::ToolCall)
320 .chain(
321 message
322 .content
323 .map(|content| message::AssistantContent::text(&content))
324 .into_iter(),
325 ),
326 )
327 .map_err(|_| {
328 message::MessageError::ConversionError("Empty assistant message".to_string())
329 })?,
330 }),
331 _ => Err(message::MessageError::ConversionError(format!(
332 "Unknown role: {}",
333 message.role
334 ))),
335 }
336 }
337}
338
339impl TryFrom<message::Message> for Message {
340 type Error = message::MessageError;
341
342 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
343 match message {
344 message::Message::User { content } => Ok(Self {
345 role: "user".to_string(),
346 content: content.iter().find_map(|c| match c {
347 message::UserContent::Text(text) => Some(text.text.clone()),
348 _ => None,
349 }),
350 tool_calls: vec![],
351 }),
352 message::Message::Assistant { content, .. } => {
353 let mut text_content: Option<String> = None;
354 let mut tool_calls = vec![];
355
356 for c in content.iter() {
357 match c {
358 message::AssistantContent::Text(text) => {
359 text_content = Some(
360 text_content
361 .map(|mut existing| {
362 existing.push('\n');
363 existing.push_str(&text.text);
364 existing
365 })
366 .unwrap_or_else(|| text.text.clone()),
367 );
368 }
369 message::AssistantContent::ToolCall(tool_call) => {
370 tool_calls.push(tool_call.clone().into());
371 }
372 }
373 }
374
375 Ok(Self {
376 role: "assistant".to_string(),
377 content: text_content,
378 tool_calls,
379 })
380 }
381 }
382 }
383}
384
/// Wire-format wrapper around a crate-level tool definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind; the `From` conversion in this file always sets it to `"function"`.
    pub r#type: String,
    /// The wrapped tool definition.
    pub function: completion::ToolDefinition,
}
390
391impl From<completion::ToolDefinition> for ToolDefinition {
392 fn from(tool: completion::ToolDefinition) -> Self {
393 Self {
394 r#type: "function".into(),
395 function: tool,
396 }
397 }
398}
399
/// A function name paired with its arguments as a string.
///
/// NOTE(review): not referenced anywhere in the visible part of this file;
/// tool calls here use `openai::ToolCall` instead — confirm whether this type
/// is still needed.
#[derive(Debug, Deserialize)]
pub struct Function {
    /// Name of the function.
    pub name: String,
    /// Arguments as a string.
    /// NOTE(review): presumably JSON-encoded — confirm.
    pub arguments: String,
}
405
/// A completion model handle: a client plus the model name to request.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests for this model.
    client: Client,
    /// Model name placed in the `model` field of every request body.
    pub model: String,
}
412
413impl CompletionModel {
414 pub(crate) fn create_completion_request(
415 &self,
416 completion_request: CompletionRequest,
417 ) -> Result<Value, CompletionError> {
418 let mut partial_history = vec![];
420 if let Some(docs) = completion_request.normalized_documents() {
421 partial_history.push(docs);
422 }
423 partial_history.extend(completion_request.chat_history);
424
425 let mut full_history: Vec<Message> = match &completion_request.preamble {
427 Some(preamble) => vec![Message {
428 role: "system".to_string(),
429 content: Some(preamble.to_string()),
430 tool_calls: vec![],
431 }],
432 None => vec![],
433 };
434
435 full_history.extend(
437 partial_history
438 .into_iter()
439 .map(message::Message::try_into)
440 .collect::<Result<Vec<Message>, _>>()?,
441 );
442
443 let request = if completion_request.tools.is_empty() {
444 json!({
445 "model": self.model,
446 "messages": full_history,
447 "temperature": completion_request.temperature,
448 })
449 } else {
450 json!({
451 "model": self.model,
452 "messages": full_history,
453 "temperature": completion_request.temperature,
454 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
455 "tool_choice": "auto",
456 })
457 };
458
459 let request = if let Some(params) = completion_request.additional_params {
460 json_utils::merge(request, params)
461 } else {
462 request
463 };
464
465 Ok(request)
466 }
467}
468
469impl CompletionModel {
470 pub fn new(client: Client, model: &str) -> Self {
471 Self {
472 client,
473 model: model.to_string(),
474 }
475 }
476}
477
478impl completion::CompletionModel for CompletionModel {
479 type Response = CompletionResponse;
480 type StreamingResponse = openai::StreamingCompletionResponse;
481
482 #[cfg_attr(feature = "worker", worker::send)]
483 async fn completion(
484 &self,
485 completion_request: CompletionRequest,
486 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
487 let request = self.create_completion_request(completion_request)?;
488
489 let response = self
490 .client
491 .post("/chat/completions")
492 .json(&request)
493 .send()
494 .await?;
495
496 if response.status().is_success() {
497 let t = response.text().await?;
498 tracing::debug!(target: "rig", "Galadriel completion error: {}", t);
499
500 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
501 ApiResponse::Ok(response) => {
502 tracing::info!(target: "rig",
503 "Galadriel completion token usage: {:?}",
504 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
505 );
506 response.try_into()
507 }
508 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
509 }
510 } else {
511 Err(CompletionError::ProviderError(response.text().await?))
512 }
513 }
514
515 #[cfg_attr(feature = "worker", worker::send)]
516 async fn stream(
517 &self,
518 request: CompletionRequest,
519 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
520 let mut request = self.create_completion_request(request)?;
521
522 request = merge(
523 request,
524 json!({"stream": true, "stream_options": {"include_usage": true}}),
525 );
526
527 let builder = self.client.post("/chat/completions").json(&request);
528
529 send_compatible_streaming_request(builder).await
530 }
531}