1use super::openai;
14use crate::client::{CompletionClient, ProviderClient};
15use crate::json_utils::merge;
16use crate::providers::openai::send_compatible_streaming_request;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19 completion::{self, CompletionError, CompletionRequest},
20 impl_conversion_traits, json_utils, message, OneOrMany,
21};
22use serde::{Deserialize, Serialize};
23use serde_json::{json, Value};
24
/// Base URL used by `Client::new` when the caller does not supply one.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
29
30#[derive(Clone, Debug)]
31pub struct Client {
32 base_url: String,
33 http_client: reqwest::Client,
34}
35
36impl Client {
37 pub fn new(api_key: &str, fine_tune_api_key: Option<&str>) -> Self {
39 Self::from_url_with_optional_key(api_key, GALADRIEL_API_BASE_URL, fine_tune_api_key)
40 }
41
42 pub fn from_url(api_key: &str, base_url: &str, fine_tune_api_key: Option<&str>) -> Self {
44 Self::from_url_with_optional_key(api_key, base_url, fine_tune_api_key)
45 }
46
47 pub fn from_url_with_optional_key(
48 api_key: &str,
49 base_url: &str,
50 fine_tune_api_key: Option<&str>,
51 ) -> Self {
52 Self {
53 base_url: base_url.to_string(),
54 http_client: reqwest::Client::builder()
55 .default_headers({
56 let mut headers = reqwest::header::HeaderMap::new();
57 headers.insert(
58 "Authorization",
59 format!("Bearer {api_key}")
60 .parse()
61 .expect("Bearer token should parse"),
62 );
63 if let Some(key) = fine_tune_api_key {
64 headers.insert(
65 "Fine-Tune-Authorization",
66 format!("Bearer {key}")
67 .parse()
68 .expect("Bearer token should parse"),
69 );
70 }
71 headers
72 })
73 .build()
74 .expect("Galadriel reqwest client should build"),
75 }
76 }
77
78 fn post(&self, path: &str) -> reqwest::RequestBuilder {
79 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
80 self.http_client.post(url)
81 }
82}
83
84impl ProviderClient for Client {
85 fn from_env() -> Self {
89 let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
90 let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
91 Self::new(&api_key, fine_tune_api_key.as_deref())
92 }
93}
94
95impl CompletionClient for Client {
96 type CompletionModel = CompletionModel;
97
98 fn completion_model(&self, model: &str) -> CompletionModel {
110 CompletionModel::new(self.clone(), model)
111 }
112}
113
// Implements the `AsEmbeddings`, `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` conversion traits for `Client` through the crate's
// `impl_conversion_traits!` macro.
// NOTE(review): presumably these are "capability not supported" stubs, since this
// file only defines a completion model — confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
120
121#[derive(Debug, Deserialize)]
122struct ApiErrorResponse {
123 message: String,
124}
125
126#[derive(Debug, Deserialize)]
127#[serde(untagged)]
128enum ApiResponse<T> {
129 Ok(T),
130 Err(ApiErrorResponse),
131}
132
133#[derive(Clone, Debug, Deserialize)]
134pub struct Usage {
135 pub prompt_tokens: usize,
136 pub total_tokens: usize,
137}
138
139impl std::fmt::Display for Usage {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 write!(
142 f,
143 "Prompt tokens: {} Total tokens: {}",
144 self.prompt_tokens, self.total_tokens
145 )
146 }
147}
148
// Model name constants. Each value is the literal model identifier string; nothing
// in this file restricts `completion_model` to these names.

/// `o1-preview` model identifier.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` model identifier.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` model identifier.
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` model identifier.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` model identifier.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` model identifier.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` model identifier.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` model identifier.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` model identifier.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` model identifier.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` model identifier.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` model identifier.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` model identifier.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` model identifier.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` model identifier.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` model identifier.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` model identifier.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` model identifier.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` model identifier.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` model identifier.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` model identifier.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
194
195#[derive(Debug, Deserialize)]
196pub struct CompletionResponse {
197 pub id: String,
198 pub object: String,
199 pub created: u64,
200 pub model: String,
201 pub system_fingerprint: Option<String>,
202 pub choices: Vec<Choice>,
203 pub usage: Option<Usage>,
204}
205
206impl From<ApiErrorResponse> for CompletionError {
207 fn from(err: ApiErrorResponse) -> Self {
208 CompletionError::ProviderError(err.message)
209 }
210}
211
212impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
213 type Error = CompletionError;
214
215 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
216 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
217 CompletionError::ResponseError("Response contained no choices".to_owned())
218 })?;
219
220 let mut content = message
221 .content
222 .as_ref()
223 .map(|c| vec![completion::AssistantContent::text(c)])
224 .unwrap_or_default();
225
226 content.extend(message.tool_calls.iter().map(|call| {
227 completion::AssistantContent::tool_call(
228 &call.function.name,
229 &call.function.name,
230 call.function.arguments.clone(),
231 )
232 }));
233
234 let choice = OneOrMany::many(content).map_err(|_| {
235 CompletionError::ResponseError(
236 "Response contained no message or tool call (empty)".to_owned(),
237 )
238 })?;
239
240 Ok(completion::CompletionResponse {
241 choice,
242 raw_response: response,
243 })
244 }
245}
246
247#[derive(Debug, Deserialize)]
248pub struct Choice {
249 pub index: usize,
250 pub message: Message,
251 pub logprobs: Option<serde_json::Value>,
252 pub finish_reason: String,
253}
254
255#[derive(Debug, Serialize, Deserialize)]
256pub struct Message {
257 pub role: String,
258 pub content: Option<String>,
259 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
260 pub tool_calls: Vec<openai::ToolCall>,
261}
262
263impl TryFrom<Message> for message::Message {
264 type Error = message::MessageError;
265
266 fn try_from(message: Message) -> Result<Self, Self::Error> {
267 let tool_calls: Vec<message::ToolCall> = message
268 .tool_calls
269 .into_iter()
270 .map(|tool_call| tool_call.into())
271 .collect();
272
273 match message.role.as_str() {
274 "user" => Ok(Self::User {
275 content: OneOrMany::one(
276 message
277 .content
278 .map(|content| message::UserContent::text(&content))
279 .ok_or_else(|| {
280 message::MessageError::ConversionError("Empty user message".to_string())
281 })?,
282 ),
283 }),
284 "assistant" => Ok(Self::Assistant {
285 content: OneOrMany::many(
286 tool_calls
287 .into_iter()
288 .map(message::AssistantContent::ToolCall)
289 .chain(
290 message
291 .content
292 .map(|content| message::AssistantContent::text(&content))
293 .into_iter(),
294 ),
295 )
296 .map_err(|_| {
297 message::MessageError::ConversionError("Empty assistant message".to_string())
298 })?,
299 }),
300 _ => Err(message::MessageError::ConversionError(format!(
301 "Unknown role: {}",
302 message.role
303 ))),
304 }
305 }
306}
307
308impl TryFrom<message::Message> for Message {
309 type Error = message::MessageError;
310
311 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
312 match message {
313 message::Message::User { content } => Ok(Self {
314 role: "user".to_string(),
315 content: content.iter().find_map(|c| match c {
316 message::UserContent::Text(text) => Some(text.text.clone()),
317 _ => None,
318 }),
319 tool_calls: vec![],
320 }),
321 message::Message::Assistant { content } => {
322 let mut text_content: Option<String> = None;
323 let mut tool_calls = vec![];
324
325 for c in content.iter() {
326 match c {
327 message::AssistantContent::Text(text) => {
328 text_content = Some(
329 text_content
330 .map(|mut existing| {
331 existing.push('\n');
332 existing.push_str(&text.text);
333 existing
334 })
335 .unwrap_or_else(|| text.text.clone()),
336 );
337 }
338 message::AssistantContent::ToolCall(tool_call) => {
339 tool_calls.push(tool_call.clone().into());
340 }
341 }
342 }
343
344 Ok(Self {
345 role: "assistant".to_string(),
346 content: text_content,
347 tool_calls,
348 })
349 }
350 }
351 }
352}
353
354#[derive(Clone, Debug, Deserialize, Serialize)]
355pub struct ToolDefinition {
356 pub r#type: String,
357 pub function: completion::ToolDefinition,
358}
359
360impl From<completion::ToolDefinition> for ToolDefinition {
361 fn from(tool: completion::ToolDefinition) -> Self {
362 Self {
363 r#type: "function".into(),
364 function: tool,
365 }
366 }
367}
368
369#[derive(Debug, Deserialize)]
370pub struct Function {
371 pub name: String,
372 pub arguments: String,
373}
374
/// Completion model handle: a [`Client`] paired with the name of the model that
/// requests are issued for.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send the requests.
    client: Client,
    /// Model name placed in the `model` field of every request body.
    pub model: String,
}
381
382impl CompletionModel {
383 pub(crate) fn create_completion_request(
384 &self,
385 completion_request: CompletionRequest,
386 ) -> Result<Value, CompletionError> {
387 let mut partial_history = vec![];
389 if let Some(docs) = completion_request.normalized_documents() {
390 partial_history.push(docs);
391 }
392 partial_history.extend(completion_request.chat_history);
393
394 let mut full_history: Vec<Message> = match &completion_request.preamble {
396 Some(preamble) => vec![Message {
397 role: "system".to_string(),
398 content: Some(preamble.to_string()),
399 tool_calls: vec![],
400 }],
401 None => vec![],
402 };
403
404 full_history.extend(
406 partial_history
407 .into_iter()
408 .map(message::Message::try_into)
409 .collect::<Result<Vec<Message>, _>>()?,
410 );
411
412 let request = if completion_request.tools.is_empty() {
413 json!({
414 "model": self.model,
415 "messages": full_history,
416 "temperature": completion_request.temperature,
417 })
418 } else {
419 json!({
420 "model": self.model,
421 "messages": full_history,
422 "temperature": completion_request.temperature,
423 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
424 "tool_choice": "auto",
425 })
426 };
427
428 let request = if let Some(params) = completion_request.additional_params {
429 json_utils::merge(request, params)
430 } else {
431 request
432 };
433
434 Ok(request)
435 }
436}
437
438impl CompletionModel {
439 pub fn new(client: Client, model: &str) -> Self {
440 Self {
441 client,
442 model: model.to_string(),
443 }
444 }
445}
446
447impl completion::CompletionModel for CompletionModel {
448 type Response = CompletionResponse;
449 type StreamingResponse = openai::StreamingCompletionResponse;
450
451 #[cfg_attr(feature = "worker", worker::send)]
452 async fn completion(
453 &self,
454 completion_request: CompletionRequest,
455 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
456 let request = self.create_completion_request(completion_request)?;
457
458 let response = self
459 .client
460 .post("/chat/completions")
461 .json(&request)
462 .send()
463 .await?;
464
465 if response.status().is_success() {
466 let t = response.text().await?;
467 tracing::debug!(target: "rig", "Galadriel completion error: {}", t);
468
469 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
470 ApiResponse::Ok(response) => {
471 tracing::info!(target: "rig",
472 "Galadriel completion token usage: {:?}",
473 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
474 );
475 response.try_into()
476 }
477 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
478 }
479 } else {
480 Err(CompletionError::ProviderError(response.text().await?))
481 }
482 }
483
484 #[cfg_attr(feature = "worker", worker::send)]
485 async fn stream(
486 &self,
487 request: CompletionRequest,
488 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
489 let mut request = self.create_completion_request(request)?;
490
491 request = merge(
492 request,
493 json!({"stream": true, "stream_options": {"include_usage": true}}),
494 );
495
496 let builder = self.client.post("/chat/completions").json(&request);
497
498 send_compatible_streaming_request(builder).await
499 }
500}