1use super::openai;
14use crate::client::{CompletionClient, ProviderClient};
15use crate::json_utils::merge;
16use crate::providers::openai::send_compatible_streaming_request;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19 OneOrMany,
20 completion::{self, CompletionError, CompletionRequest},
21 impl_conversion_traits, json_utils, message,
22};
23use serde::{Deserialize, Serialize};
24use serde_json::{Value, json};
25
/// Base URL that [`Client::new`] passes on when no custom URL is supplied.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
30
/// Connection settings and credentials used to talk to the Galadriel API.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// API key sent as a bearer token on every request.
    api_key: String,
    /// Optional key sent in the `Fine-Tune-Authorization` header when present.
    fine_tune_api_key: Option<String>,
    /// Underlying `reqwest` client used to issue requests.
    http_client: reqwest::Client,
}
38
39impl std::fmt::Debug for Client {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.debug_struct("Client")
42 .field("base_url", &self.base_url)
43 .field("http_client", &self.http_client)
44 .field("api_key", &"<REDACTED>")
45 .field("fine_tune_api_key", &"<REDACTED>")
46 .finish()
47 }
48}
49
50impl Client {
51 pub fn new(api_key: &str, fine_tune_api_key: Option<&str>) -> Self {
53 Self::from_url_with_optional_key(api_key, GALADRIEL_API_BASE_URL, fine_tune_api_key)
54 }
55
56 pub fn from_url(api_key: &str, base_url: &str, fine_tune_api_key: Option<&str>) -> Self {
58 Self::from_url_with_optional_key(api_key, base_url, fine_tune_api_key)
59 }
60
61 pub fn from_url_with_optional_key(
62 api_key: &str,
63 base_url: &str,
64 fine_tune_api_key: Option<&str>,
65 ) -> Self {
66 Self {
67 base_url: base_url.to_string(),
68 api_key: api_key.to_string(),
69 fine_tune_api_key: fine_tune_api_key.map(|x| x.to_string()),
70 http_client: reqwest::Client::builder()
71 .build()
72 .expect("Galadriel reqwest client should build"),
73 }
74 }
75
76 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
79 self.http_client = client;
80
81 self
82 }
83
84 fn post(&self, path: &str) -> reqwest::RequestBuilder {
85 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
86 let mut client = self.http_client.post(url).bearer_auth(&self.api_key);
87
88 if let Some(fine_tune_key) = self.fine_tune_api_key.clone() {
89 client = client.header("Fine-Tune-Authorization", fine_tune_key);
90 }
91
92 client
93 }
94}
95
96impl ProviderClient for Client {
97 fn from_env() -> Self {
101 let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
102 let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
103 Self::new(&api_key, fine_tune_api_key.as_deref())
104 }
105}
106
107impl CompletionClient for Client {
108 type CompletionModel = CompletionModel;
109
110 fn completion_model(&self, model: &str) -> CompletionModel {
122 CompletionModel::new(self.clone(), model)
123 }
124}
125
// NOTE(review): presumably generates the listed conversion-trait impls for
// `Client` (capabilities other than completion) — confirm against the macro
// definition, which is not visible here.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
132
/// Error payload returned by the API; only the `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text reported by the provider.
    message: String,
}
137
/// Response body that is either a successful payload `T` or an error payload.
///
/// The enum is untagged, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error payload.
    Err(ApiErrorResponse),
}
144
/// Token usage figures reported alongside a completion.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Value of the `prompt_tokens` field of the response.
    pub prompt_tokens: usize,
    /// Value of the `total_tokens` field of the response.
    pub total_tokens: usize,
}
150
151impl std::fmt::Display for Usage {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 write!(
154 f,
155 "Prompt tokens: {} Total tokens: {}",
156 self.prompt_tokens, self.total_tokens
157 )
158 }
159}
160
// Model name strings. NOTE(review): presumably intended to be passed as the
// `model` argument of `completion_model` — confirm against callers.

/// The `o1-preview` model name.
pub const O1_PREVIEW: &str = "o1-preview";
/// The `o1-preview-2024-09-12` model name.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// The `o1-mini` model name.
pub const O1_MINI: &str = "o1-mini";
/// The `o1-mini-2024-09-12` model name.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// The `gpt-4o` model name.
pub const GPT_4O: &str = "gpt-4o";
/// The `gpt-4o-2024-05-13` model name.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// The `gpt-4-turbo` model name.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// The `gpt-4-turbo-2024-04-09` model name.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// The `gpt-4-turbo-preview` model name.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// The `gpt-4-0125-preview` model name.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// The `gpt-4-1106-preview` model name.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// The `gpt-4-vision-preview` model name.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// The `gpt-4-1106-vision-preview` model name.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// The `gpt-4` model name.
pub const GPT_4: &str = "gpt-4";
/// The `gpt-4-0613` model name.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// The `gpt-4-32k` model name.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// The `gpt-4-32k-0613` model name.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// The `gpt-3.5-turbo` model name.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// The `gpt-3.5-turbo-0125` model name.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// The `gpt-3.5-turbo-1106` model name.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// The `gpt-3.5-turbo-instruct` model name.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
206
/// Raw chat completion payload as deserialized from the API response body.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Value of the `id` field of the response.
    pub id: String,
    /// Value of the `object` field of the response.
    pub object: String,
    /// Value of the `created` field (presumably a Unix timestamp — TODO confirm).
    pub created: u64,
    /// Value of the `model` field of the response.
    pub model: String,
    /// Value of the `system_fingerprint` field, when present.
    pub system_fingerprint: Option<String>,
    /// Returned choices; the conversion into the generic completion response
    /// only reads the first one.
    pub choices: Vec<Choice>,
    /// Token usage, when the response includes it.
    pub usage: Option<Usage>,
}
217
218impl From<ApiErrorResponse> for CompletionError {
219 fn from(err: ApiErrorResponse) -> Self {
220 CompletionError::ProviderError(err.message)
221 }
222}
223
224impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
225 type Error = CompletionError;
226
227 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
228 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
229 CompletionError::ResponseError("Response contained no choices".to_owned())
230 })?;
231
232 let mut content = message
233 .content
234 .as_ref()
235 .map(|c| vec![completion::AssistantContent::text(c)])
236 .unwrap_or_default();
237
238 content.extend(message.tool_calls.iter().map(|call| {
239 completion::AssistantContent::tool_call(
240 &call.function.name,
241 &call.function.name,
242 call.function.arguments.clone(),
243 )
244 }));
245
246 let choice = OneOrMany::many(content).map_err(|_| {
247 CompletionError::ResponseError(
248 "Response contained no message or tool call (empty)".to_owned(),
249 )
250 })?;
251
252 Ok(completion::CompletionResponse {
253 choice,
254 raw_response: response,
255 })
256 }
257}
258
/// One entry of the `choices` array of a completion response.
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Value of the `index` field of the choice.
    pub index: usize,
    /// Message carried by this choice.
    pub message: Message,
    /// Value of the `logprobs` field, kept as untyped JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Value of the `finish_reason` field of the choice.
    pub finish_reason: String,
}
266
/// Chat message in the wire format exchanged with the API.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Role name; this file produces or accepts `"system"`, `"user"` and `"assistant"`.
    pub role: String,
    /// Text content of the message, if any.
    pub content: Option<String>,
    /// Tool calls attached to the message; defaults to empty when the field is
    /// missing (presumably `null` also maps to empty — confirm in
    /// `json_utils::null_or_vec`).
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<openai::ToolCall>,
}
274
275impl TryFrom<Message> for message::Message {
276 type Error = message::MessageError;
277
278 fn try_from(message: Message) -> Result<Self, Self::Error> {
279 let tool_calls: Vec<message::ToolCall> = message
280 .tool_calls
281 .into_iter()
282 .map(|tool_call| tool_call.into())
283 .collect();
284
285 match message.role.as_str() {
286 "user" => Ok(Self::User {
287 content: OneOrMany::one(
288 message
289 .content
290 .map(|content| message::UserContent::text(&content))
291 .ok_or_else(|| {
292 message::MessageError::ConversionError("Empty user message".to_string())
293 })?,
294 ),
295 }),
296 "assistant" => Ok(Self::Assistant {
297 id: None,
298 content: OneOrMany::many(
299 tool_calls
300 .into_iter()
301 .map(message::AssistantContent::ToolCall)
302 .chain(
303 message
304 .content
305 .map(|content| message::AssistantContent::text(&content))
306 .into_iter(),
307 ),
308 )
309 .map_err(|_| {
310 message::MessageError::ConversionError("Empty assistant message".to_string())
311 })?,
312 }),
313 _ => Err(message::MessageError::ConversionError(format!(
314 "Unknown role: {}",
315 message.role
316 ))),
317 }
318 }
319}
320
321impl TryFrom<message::Message> for Message {
322 type Error = message::MessageError;
323
324 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
325 match message {
326 message::Message::User { content } => Ok(Self {
327 role: "user".to_string(),
328 content: content.iter().find_map(|c| match c {
329 message::UserContent::Text(text) => Some(text.text.clone()),
330 _ => None,
331 }),
332 tool_calls: vec![],
333 }),
334 message::Message::Assistant { content, .. } => {
335 let mut text_content: Option<String> = None;
336 let mut tool_calls = vec![];
337
338 for c in content.iter() {
339 match c {
340 message::AssistantContent::Text(text) => {
341 text_content = Some(
342 text_content
343 .map(|mut existing| {
344 existing.push('\n');
345 existing.push_str(&text.text);
346 existing
347 })
348 .unwrap_or_else(|| text.text.clone()),
349 );
350 }
351 message::AssistantContent::ToolCall(tool_call) => {
352 tool_calls.push(tool_call.clone().into());
353 }
354 }
355 }
356
357 Ok(Self {
358 role: "assistant".to_string(),
359 content: text_content,
360 tool_calls,
361 })
362 }
363 }
364 }
365}
366
/// Tool entry as placed in the `tools` array of a request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind; the `From` conversion in this file always sets `"function"`.
    pub r#type: String,
    /// The generic tool definition being wrapped.
    pub function: completion::ToolDefinition,
}
372
373impl From<completion::ToolDefinition> for ToolDefinition {
374 fn from(tool: completion::ToolDefinition) -> Self {
375 Self {
376 r#type: "function".into(),
377 function: tool,
378 }
379 }
380}
381
/// A function name paired with its arguments, both as strings.
#[derive(Debug, Deserialize)]
pub struct Function {
    /// Name of the function.
    pub name: String,
    /// Arguments as a string (presumably JSON-encoded — TODO confirm).
    pub arguments: String,
}
387
/// A completion model bound to a [`Client`] and a model name.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    client: Client,
    /// Model name placed in the `model` field of each request.
    pub model: String,
}
394
395impl CompletionModel {
396 pub(crate) fn create_completion_request(
397 &self,
398 completion_request: CompletionRequest,
399 ) -> Result<Value, CompletionError> {
400 let mut partial_history = vec![];
402 if let Some(docs) = completion_request.normalized_documents() {
403 partial_history.push(docs);
404 }
405 partial_history.extend(completion_request.chat_history);
406
407 let mut full_history: Vec<Message> = match &completion_request.preamble {
409 Some(preamble) => vec![Message {
410 role: "system".to_string(),
411 content: Some(preamble.to_string()),
412 tool_calls: vec![],
413 }],
414 None => vec![],
415 };
416
417 full_history.extend(
419 partial_history
420 .into_iter()
421 .map(message::Message::try_into)
422 .collect::<Result<Vec<Message>, _>>()?,
423 );
424
425 let request = if completion_request.tools.is_empty() {
426 json!({
427 "model": self.model,
428 "messages": full_history,
429 "temperature": completion_request.temperature,
430 })
431 } else {
432 json!({
433 "model": self.model,
434 "messages": full_history,
435 "temperature": completion_request.temperature,
436 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
437 "tool_choice": "auto",
438 })
439 };
440
441 let request = if let Some(params) = completion_request.additional_params {
442 json_utils::merge(request, params)
443 } else {
444 request
445 };
446
447 Ok(request)
448 }
449}
450
451impl CompletionModel {
452 pub fn new(client: Client, model: &str) -> Self {
453 Self {
454 client,
455 model: model.to_string(),
456 }
457 }
458}
459
460impl completion::CompletionModel for CompletionModel {
461 type Response = CompletionResponse;
462 type StreamingResponse = openai::StreamingCompletionResponse;
463
464 #[cfg_attr(feature = "worker", worker::send)]
465 async fn completion(
466 &self,
467 completion_request: CompletionRequest,
468 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
469 let request = self.create_completion_request(completion_request)?;
470
471 let response = self
472 .client
473 .post("/chat/completions")
474 .json(&request)
475 .send()
476 .await?;
477
478 if response.status().is_success() {
479 let t = response.text().await?;
480 tracing::debug!(target: "rig", "Galadriel completion error: {}", t);
481
482 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
483 ApiResponse::Ok(response) => {
484 tracing::info!(target: "rig",
485 "Galadriel completion token usage: {:?}",
486 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
487 );
488 response.try_into()
489 }
490 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
491 }
492 } else {
493 Err(CompletionError::ProviderError(response.text().await?))
494 }
495 }
496
497 #[cfg_attr(feature = "worker", worker::send)]
498 async fn stream(
499 &self,
500 request: CompletionRequest,
501 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
502 let mut request = self.create_completion_request(request)?;
503
504 request = merge(
505 request,
506 json!({"stream": true, "stream_options": {"include_usage": true}}),
507 );
508
509 let builder = self.client.post("/chat/completions").json(&request);
510
511 send_compatible_streaming_request(builder).await
512 }
513}