1use super::openai;
14use crate::json_utils::merge;
15use crate::providers::openai::send_compatible_streaming_request;
16use crate::streaming::{StreamingCompletionModel, StreamingResult};
17use crate::{
18 agent::AgentBuilder,
19 completion::{self, CompletionError, CompletionRequest},
20 extractor::ExtractorBuilder,
21 json_utils, message, OneOrMany,
22};
23use schemars::JsonSchema;
24use serde::{Deserialize, Serialize};
25use serde_json::{json, Value};
26
/// Base URL of the Galadriel verified (OpenAI-compatible) inference API.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
31
/// Galadriel API client: a base URL plus a `reqwest` client whose default
/// headers carry the API key(s) set at construction time.
#[derive(Clone)]
pub struct Client {
    // API root; endpoint paths are appended by `post`.
    base_url: String,
    // HTTP client with `Authorization` (and optionally
    // `Fine-Tune-Authorization`) installed as default headers.
    http_client: reqwest::Client,
}
37
38impl Client {
39 pub fn new(api_key: &str, fine_tune_api_key: Option<&str>) -> Self {
41 Self::from_url_with_optional_key(api_key, GALADRIEL_API_BASE_URL, fine_tune_api_key)
42 }
43
44 pub fn from_url(api_key: &str, base_url: &str, fine_tune_api_key: Option<&str>) -> Self {
46 Self::from_url_with_optional_key(api_key, base_url, fine_tune_api_key)
47 }
48
49 pub fn from_url_with_optional_key(
50 api_key: &str,
51 base_url: &str,
52 fine_tune_api_key: Option<&str>,
53 ) -> Self {
54 Self {
55 base_url: base_url.to_string(),
56 http_client: reqwest::Client::builder()
57 .default_headers({
58 let mut headers = reqwest::header::HeaderMap::new();
59 headers.insert(
60 "Authorization",
61 format!("Bearer {}", api_key)
62 .parse()
63 .expect("Bearer token should parse"),
64 );
65 if let Some(key) = fine_tune_api_key {
66 headers.insert(
67 "Fine-Tune-Authorization",
68 format!("Bearer {}", key)
69 .parse()
70 .expect("Bearer token should parse"),
71 );
72 }
73 headers
74 })
75 .build()
76 .expect("Galadriel reqwest client should build"),
77 }
78 }
79
80 pub fn from_env() -> Self {
84 let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
85 let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
86 Self::new(&api_key, fine_tune_api_key.as_deref())
87 }
88 fn post(&self, path: &str) -> reqwest::RequestBuilder {
89 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
90 self.http_client.post(url)
91 }
92
93 pub fn completion_model(&self, model: &str) -> CompletionModel {
105 CompletionModel::new(self.clone(), model)
106 }
107
108 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
123 AgentBuilder::new(self.completion_model(model))
124 }
125
126 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
128 &self,
129 model: &str,
130 ) -> ExtractorBuilder<T, CompletionModel> {
131 ExtractorBuilder::new(self.completion_model(model))
132 }
133}
134
/// Error payload returned by the Galadriel API; only the message is kept.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
139
/// Untagged API envelope: serde first tries to deserialize the expected
/// payload `T`, and falls back to the error shape on mismatch.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
146
/// Token usage as reported by the Galadriel API.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    // Tokens consumed by the prompt.
    pub prompt_tokens: usize,
    // Total tokens (prompt + completion) for the request.
    pub total_tokens: usize,
}
152
153impl std::fmt::Display for Usage {
154 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155 write!(
156 f,
157 "Prompt tokens: {} Total tokens: {}",
158 self.prompt_tokens, self.total_tokens
159 )
160 }
161}
162
// Model identifiers accepted by the Galadriel verified inference API
// (mirroring the OpenAI model catalogue).
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
208
/// Raw chat-completion response body (OpenAI-compatible schema).
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    // One entry per generated choice; the first is used for conversion.
    pub choices: Vec<Choice>,
    // Token accounting; may be absent in some responses.
    pub usage: Option<Usage>,
}
219
220impl From<ApiErrorResponse> for CompletionError {
221 fn from(err: ApiErrorResponse) -> Self {
222 CompletionError::ProviderError(err.message)
223 }
224}
225
226impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
227 type Error = CompletionError;
228
229 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
230 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
231 CompletionError::ResponseError("Response contained no choices".to_owned())
232 })?;
233
234 let mut content = message
235 .content
236 .as_ref()
237 .map(|c| vec![completion::AssistantContent::text(c)])
238 .unwrap_or_default();
239
240 content.extend(message.tool_calls.iter().map(|call| {
241 completion::AssistantContent::tool_call(
242 &call.function.name,
243 &call.function.name,
244 call.function.arguments.clone(),
245 )
246 }));
247
248 let choice = OneOrMany::many(content).map_err(|_| {
249 CompletionError::ResponseError(
250 "Response contained no message or tool call (empty)".to_owned(),
251 )
252 })?;
253
254 Ok(completion::CompletionResponse {
255 choice,
256 raw_response: response,
257 })
258 }
259}
260
/// One generated choice within a [`CompletionResponse`].
#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    // Log-probability payload, passed through unparsed.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
268
/// Chat message in the Galadriel (OpenAI-compatible) wire format.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    // "system", "user" or "assistant".
    pub role: String,
    // Text body; may be absent when the assistant only emits tool calls.
    pub content: Option<String>,
    // Tool calls; `null` in the JSON is normalized to an empty vec.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    pub tool_calls: Vec<openai::ToolCall>,
}
276
277impl TryFrom<Message> for message::Message {
278 type Error = message::MessageError;
279
280 fn try_from(message: Message) -> Result<Self, Self::Error> {
281 let tool_calls: Vec<message::ToolCall> = message
282 .tool_calls
283 .into_iter()
284 .map(|tool_call| tool_call.into())
285 .collect();
286
287 match message.role.as_str() {
288 "user" => Ok(Self::User {
289 content: OneOrMany::one(
290 message
291 .content
292 .map(|content| message::UserContent::text(&content))
293 .ok_or_else(|| {
294 message::MessageError::ConversionError("Empty user message".to_string())
295 })?,
296 ),
297 }),
298 "assistant" => Ok(Self::Assistant {
299 content: OneOrMany::many(
300 tool_calls
301 .into_iter()
302 .map(message::AssistantContent::ToolCall)
303 .chain(
304 message
305 .content
306 .map(|content| message::AssistantContent::text(&content))
307 .into_iter(),
308 ),
309 )
310 .map_err(|_| {
311 message::MessageError::ConversionError("Empty assistant message".to_string())
312 })?,
313 }),
314 _ => Err(message::MessageError::ConversionError(format!(
315 "Unknown role: {}",
316 message.role
317 ))),
318 }
319 }
320}
321
322impl TryFrom<message::Message> for Message {
323 type Error = message::MessageError;
324
325 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
326 match message {
327 message::Message::User { content } => Ok(Self {
328 role: "user".to_string(),
329 content: content.iter().find_map(|c| match c {
330 message::UserContent::Text(text) => Some(text.text.clone()),
331 _ => None,
332 }),
333 tool_calls: vec![],
334 }),
335 message::Message::Assistant { content } => {
336 let mut text_content: Option<String> = None;
337 let mut tool_calls = vec![];
338
339 for c in content.iter() {
340 match c {
341 message::AssistantContent::Text(text) => {
342 text_content = Some(
343 text_content
344 .map(|mut existing| {
345 existing.push('\n');
346 existing.push_str(&text.text);
347 existing
348 })
349 .unwrap_or_else(|| text.text.clone()),
350 );
351 }
352 message::AssistantContent::ToolCall(tool_call) => {
353 tool_calls.push(tool_call.clone().into());
354 }
355 }
356 }
357
358 Ok(Self {
359 role: "assistant".to_string(),
360 content: text_content,
361 tool_calls,
362 })
363 }
364 }
365 }
366}
367
/// OpenAI-style tool definition wrapper: a `type` tag plus the actual
/// function definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
373
374impl From<completion::ToolDefinition> for ToolDefinition {
375 fn from(tool: completion::ToolDefinition) -> Self {
376 Self {
377 r#type: "function".into(),
378 function: tool,
379 }
380 }
381}
382
/// Function invocation payload inside a tool call: the function name and its
/// arguments as a raw JSON string.
#[derive(Debug, Deserialize)]
pub struct Function {
    pub name: String,
    pub arguments: String,
}
388
/// A Galadriel completion model: a client handle plus a model identifier
/// (e.g. one of the `GPT_*`/`O1_*` constants above).
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    pub model: String,
}
395
396impl CompletionModel {
397 pub(crate) fn create_completion_request(
398 &self,
399 completion_request: CompletionRequest,
400 ) -> Result<Value, CompletionError> {
401 let mut partial_history = vec![];
403 if let Some(docs) = completion_request.normalized_documents() {
404 partial_history.push(docs);
405 }
406 partial_history.extend(completion_request.chat_history);
407
408 let mut full_history: Vec<Message> = match &completion_request.preamble {
410 Some(preamble) => vec![Message {
411 role: "system".to_string(),
412 content: Some(preamble.to_string()),
413 tool_calls: vec![],
414 }],
415 None => vec![],
416 };
417
418 full_history.extend(
420 partial_history
421 .into_iter()
422 .map(message::Message::try_into)
423 .collect::<Result<Vec<Message>, _>>()?,
424 );
425
426 let request = if completion_request.tools.is_empty() {
427 json!({
428 "model": self.model,
429 "messages": full_history,
430 "temperature": completion_request.temperature,
431 })
432 } else {
433 json!({
434 "model": self.model,
435 "messages": full_history,
436 "temperature": completion_request.temperature,
437 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
438 "tool_choice": "auto",
439 })
440 };
441
442 let request = if let Some(params) = completion_request.additional_params {
443 json_utils::merge(request, params)
444 } else {
445 request
446 };
447
448 Ok(request)
449 }
450}
451
452impl CompletionModel {
453 pub fn new(client: Client, model: &str) -> Self {
454 Self {
455 client,
456 model: model.to_string(),
457 }
458 }
459}
460
461impl completion::CompletionModel for CompletionModel {
462 type Response = CompletionResponse;
463
464 #[cfg_attr(feature = "worker", worker::send)]
465 async fn completion(
466 &self,
467 completion_request: CompletionRequest,
468 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
469 let request = self.create_completion_request(completion_request)?;
470
471 let response = self
472 .client
473 .post("/chat/completions")
474 .json(&request)
475 .send()
476 .await?;
477
478 if response.status().is_success() {
479 let t = response.text().await?;
480 tracing::debug!(target: "rig", "Galadriel completion error: {}", t);
481
482 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
483 ApiResponse::Ok(response) => {
484 tracing::info!(target: "rig",
485 "Galadriel completion token usage: {:?}",
486 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
487 );
488 response.try_into()
489 }
490 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
491 }
492 } else {
493 Err(CompletionError::ProviderError(response.text().await?))
494 }
495 }
496}
497
498impl StreamingCompletionModel for CompletionModel {
499 async fn stream(&self, request: CompletionRequest) -> Result<StreamingResult, CompletionError> {
500 let mut request = self.create_completion_request(request)?;
501
502 request = merge(request, json!({"stream": true}));
503
504 let builder = self.client.post("/chat/completions").json(&request);
505
506 send_compatible_streaming_request(builder).await
507 }
508}