1use crate::{
14 agent::AgentBuilder,
15 completion::{self, CompletionError, CompletionRequest},
16 extractor::ExtractorBuilder,
17 json_utils, message, OneOrMany,
18};
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use serde_json::json;
22
23use super::openai;
24
/// Base URL used by [`Client::new`] and [`Client::from_env`] when the caller
/// does not supply one. Request paths are appended to it.
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
29
30#[derive(Clone)]
31pub struct Client {
32 base_url: String,
33 http_client: reqwest::Client,
34}
35
36impl Client {
37 pub fn new(api_key: &str, fine_tune_api_key: Option<&str>) -> Self {
39 Self::from_url_with_optional_key(api_key, GALADRIEL_API_BASE_URL, fine_tune_api_key)
40 }
41
42 pub fn from_url(api_key: &str, base_url: &str, fine_tune_api_key: Option<&str>) -> Self {
44 Self::from_url_with_optional_key(api_key, base_url, fine_tune_api_key)
45 }
46
47 pub fn from_url_with_optional_key(
48 api_key: &str,
49 base_url: &str,
50 fine_tune_api_key: Option<&str>,
51 ) -> Self {
52 Self {
53 base_url: base_url.to_string(),
54 http_client: reqwest::Client::builder()
55 .default_headers({
56 let mut headers = reqwest::header::HeaderMap::new();
57 headers.insert(
58 "Authorization",
59 format!("Bearer {}", api_key)
60 .parse()
61 .expect("Bearer token should parse"),
62 );
63 if let Some(key) = fine_tune_api_key {
64 headers.insert(
65 "Fine-Tune-Authorization",
66 format!("Bearer {}", key)
67 .parse()
68 .expect("Bearer token should parse"),
69 );
70 }
71 headers
72 })
73 .build()
74 .expect("Galadriel reqwest client should build"),
75 }
76 }
77
78 pub fn from_env() -> Self {
82 let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
83 let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
84 Self::new(&api_key, fine_tune_api_key.as_deref())
85 }
86 fn post(&self, path: &str) -> reqwest::RequestBuilder {
87 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
88 self.http_client.post(url)
89 }
90
91 pub fn completion_model(&self, model: &str) -> CompletionModel {
103 CompletionModel::new(self.clone(), model)
104 }
105
106 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
121 AgentBuilder::new(self.completion_model(model))
122 }
123
124 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
126 &self,
127 model: &str,
128 ) -> ExtractorBuilder<T, CompletionModel> {
129 ExtractorBuilder::new(self.completion_model(model))
130 }
131}
132
133#[derive(Debug, Deserialize)]
134struct ApiErrorResponse {
135 message: String,
136}
137
138#[derive(Debug, Deserialize)]
139#[serde(untagged)]
140enum ApiResponse<T> {
141 Ok(T),
142 Err(ApiErrorResponse),
143}
144
145#[derive(Clone, Debug, Deserialize)]
146pub struct Usage {
147 pub prompt_tokens: usize,
148 pub total_tokens: usize,
149}
150
151impl std::fmt::Display for Usage {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 write!(
154 f,
155 "Prompt tokens: {} Total tokens: {}",
156 self.prompt_tokens, self.total_tokens
157 )
158 }
159}
160
// Model identifier strings, intended to be passed as the `model` argument of
// `Client::completion_model`, `Client::agent` or `Client::extractor`.
// NOTE(review): whether the Galadriel endpoint serves every one of these
// models cannot be determined from this file — confirm against provider docs.

/// `o1-preview` model identifier.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` model identifier.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` model identifier.
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` model identifier.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` model identifier.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` model identifier.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` model identifier.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` model identifier.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` model identifier.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` model identifier.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` model identifier.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` model identifier.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` model identifier.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` model identifier.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` model identifier.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` model identifier.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` model identifier.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` model identifier.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` model identifier.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` model identifier.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` model identifier.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
206
207#[derive(Debug, Deserialize)]
208pub struct CompletionResponse {
209 pub id: String,
210 pub object: String,
211 pub created: u64,
212 pub model: String,
213 pub system_fingerprint: Option<String>,
214 pub choices: Vec<Choice>,
215 pub usage: Option<Usage>,
216}
217
218impl From<ApiErrorResponse> for CompletionError {
219 fn from(err: ApiErrorResponse) -> Self {
220 CompletionError::ProviderError(err.message)
221 }
222}
223
224impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
225 type Error = CompletionError;
226
227 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
228 let Choice { message, .. } = response.choices.first().ok_or_else(|| {
229 CompletionError::ResponseError("Response contained no choices".to_owned())
230 })?;
231
232 let mut content = message
233 .content
234 .as_ref()
235 .map(|c| vec![completion::AssistantContent::text(c)])
236 .unwrap_or_default();
237
238 content.extend(message.tool_calls.iter().map(|call| {
239 completion::AssistantContent::tool_call(
240 &call.function.name,
241 &call.function.name,
242 call.function.arguments.clone(),
243 )
244 }));
245
246 let choice = OneOrMany::many(content).map_err(|_| {
247 CompletionError::ResponseError(
248 "Response contained no message or tool call (empty)".to_owned(),
249 )
250 })?;
251
252 Ok(completion::CompletionResponse {
253 choice,
254 raw_response: response,
255 })
256 }
257}
258
259#[derive(Debug, Deserialize)]
260pub struct Choice {
261 pub index: usize,
262 pub message: Message,
263 pub logprobs: Option<serde_json::Value>,
264 pub finish_reason: String,
265}
266
267#[derive(Debug, Serialize, Deserialize)]
268pub struct Message {
269 pub role: String,
270 pub content: Option<String>,
271 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
272 pub tool_calls: Vec<openai::ToolCall>,
273}
274
275impl TryFrom<Message> for message::Message {
276 type Error = message::MessageError;
277
278 fn try_from(message: Message) -> Result<Self, Self::Error> {
279 let tool_calls: Vec<message::ToolCall> = message
280 .tool_calls
281 .into_iter()
282 .map(|tool_call| tool_call.into())
283 .collect();
284
285 match message.role.as_str() {
286 "user" => Ok(Self::User {
287 content: OneOrMany::one(
288 message
289 .content
290 .map(|content| message::UserContent::text(&content))
291 .ok_or_else(|| {
292 message::MessageError::ConversionError("Empty user message".to_string())
293 })?,
294 ),
295 }),
296 "assistant" => Ok(Self::Assistant {
297 content: OneOrMany::many(
298 tool_calls
299 .into_iter()
300 .map(message::AssistantContent::ToolCall)
301 .chain(
302 message
303 .content
304 .map(|content| message::AssistantContent::text(&content))
305 .into_iter(),
306 ),
307 )
308 .map_err(|_| {
309 message::MessageError::ConversionError("Empty assistant message".to_string())
310 })?,
311 }),
312 _ => Err(message::MessageError::ConversionError(format!(
313 "Unknown role: {}",
314 message.role
315 ))),
316 }
317 }
318}
319
320impl TryFrom<message::Message> for Message {
321 type Error = message::MessageError;
322
323 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
324 match message {
325 message::Message::User { content } => Ok(Self {
326 role: "user".to_string(),
327 content: content.iter().find_map(|c| match c {
328 message::UserContent::Text(text) => Some(text.text.clone()),
329 _ => None,
330 }),
331 tool_calls: vec![],
332 }),
333 message::Message::Assistant { content } => {
334 let mut text_content: Option<String> = None;
335 let mut tool_calls = vec![];
336
337 for c in content.iter() {
338 match c {
339 message::AssistantContent::Text(text) => {
340 text_content = Some(
341 text_content
342 .map(|mut existing| {
343 existing.push('\n');
344 existing.push_str(&text.text);
345 existing
346 })
347 .unwrap_or_else(|| text.text.clone()),
348 );
349 }
350 message::AssistantContent::ToolCall(tool_call) => {
351 tool_calls.push(tool_call.clone().into());
352 }
353 }
354 }
355
356 Ok(Self {
357 role: "assistant".to_string(),
358 content: text_content,
359 tool_calls,
360 })
361 }
362 }
363 }
364}
365
366#[derive(Clone, Debug, Deserialize, Serialize)]
367pub struct ToolDefinition {
368 pub r#type: String,
369 pub function: completion::ToolDefinition,
370}
371
372impl From<completion::ToolDefinition> for ToolDefinition {
373 fn from(tool: completion::ToolDefinition) -> Self {
374 Self {
375 r#type: "function".into(),
376 function: tool,
377 }
378 }
379}
380
381#[derive(Debug, Deserialize)]
382pub struct Function {
383 pub name: String,
384 pub arguments: String,
385}
386
387#[derive(Clone)]
388pub struct CompletionModel {
389 client: Client,
390 pub model: String,
392}
393
394impl CompletionModel {
395 pub fn new(client: Client, model: &str) -> Self {
396 Self {
397 client,
398 model: model.to_string(),
399 }
400 }
401}
402
403impl completion::CompletionModel for CompletionModel {
404 type Response = CompletionResponse;
405
406 #[cfg_attr(feature = "worker", worker::send)]
407 async fn completion(
408 &self,
409 completion_request: CompletionRequest,
410 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
411 let mut full_history: Vec<Message> = match &completion_request.preamble {
413 Some(preamble) => vec![Message {
414 role: "system".to_string(),
415 content: Some(preamble.to_string()),
416 tool_calls: vec![],
417 }],
418 None => vec![],
419 };
420
421 let prompt: Message = completion_request.prompt_with_context().try_into()?;
423
424 let chat_history: Vec<Message> = completion_request
426 .chat_history
427 .into_iter()
428 .map(|message| message.try_into())
429 .collect::<Result<Vec<Message>, _>>()?;
430
431 full_history.extend(chat_history);
433 full_history.push(prompt);
434
435 let request = if completion_request.tools.is_empty() {
436 json!({
437 "model": self.model,
438 "messages": full_history,
439 "temperature": completion_request.temperature,
440 })
441 } else {
442 json!({
443 "model": self.model,
444 "messages": full_history,
445 "temperature": completion_request.temperature,
446 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
447 "tool_choice": "auto",
448 })
449 };
450
451 let response = self
452 .client
453 .post("/chat/completions")
454 .json(
455 &if let Some(params) = completion_request.additional_params {
456 json_utils::merge(request, params)
457 } else {
458 request
459 },
460 )
461 .send()
462 .await?;
463
464 if response.status().is_success() {
465 let t = response.text().await?;
466 tracing::debug!(target: "rig", "Galadriel completion error: {}", t);
467
468 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
469 ApiResponse::Ok(response) => {
470 tracing::info!(target: "rig",
471 "Galadriel completion token usage: {:?}",
472 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
473 );
474 response.try_into()
475 }
476 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
477 }
478 } else {
479 Err(CompletionError::ProviderError(response.text().await?))
480 }
481 }
482}