1use crate::client::{CompletionClient, ProviderClient};
11use crate::json_utils::merge;
12use crate::providers::openai;
13use crate::providers::openai::send_compatible_streaming_request;
14use crate::streaming::StreamingCompletionResponse;
15use crate::{
16 completion::{self, CompletionError, CompletionRequest},
17 impl_conversion_traits,
18 message::{self, AssistantContent, Message, UserContent},
19 OneOrMany,
20};
21use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
22use serde::Deserialize;
23use serde_json::{json, Value};
24use std::string::FromUtf8Error;
25use thiserror::Error;
26use tracing;
27
28#[derive(Debug, Error)]
29pub enum MiraError {
30 #[error("Invalid API key")]
31 InvalidApiKey,
32 #[error("API error: {0}")]
33 ApiError(u16),
34 #[error("Request error: {0}")]
35 RequestError(#[from] reqwest::Error),
36 #[error("UTF-8 error: {0}")]
37 Utf8Error(#[from] FromUtf8Error),
38 #[error("JSON error: {0}")]
39 JsonError(#[from] serde_json::Error),
40}
41
42#[derive(Debug, Deserialize)]
43struct ApiErrorResponse {
44 message: String,
45}
46
47#[derive(Debug, Deserialize, Clone)]
48pub struct RawMessage {
49 pub role: String,
50 pub content: String,
51}
52
/// Base URL used by `Client::new`; endpoint paths such as `/v1/models` are appended to it.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
54
55impl TryFrom<RawMessage> for message::Message {
56 type Error = CompletionError;
57
58 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
59 match raw.role.as_str() {
60 "user" => Ok(message::Message::User {
61 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
62 }),
63 "assistant" => Ok(message::Message::Assistant {
64 content: OneOrMany::one(AssistantContent::Text(message::Text {
65 text: raw.content,
66 })),
67 }),
68 _ => Err(CompletionError::ResponseError(format!(
69 "Unsupported message role: {}",
70 raw.role
71 ))),
72 }
73 }
74}
75
76#[derive(Debug, Deserialize)]
77#[serde(untagged)]
78pub enum CompletionResponse {
79 Structured {
80 id: String,
81 object: String,
82 created: u64,
83 model: String,
84 choices: Vec<ChatChoice>,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 usage: Option<Usage>,
87 },
88 Simple(String),
89}
90
91#[derive(Debug, Deserialize)]
92pub struct ChatChoice {
93 pub message: RawMessage,
94 #[serde(default)]
95 pub finish_reason: Option<String>,
96 #[serde(default)]
97 pub index: Option<usize>,
98}
99
100#[derive(Debug, Deserialize)]
101struct ModelsResponse {
102 data: Vec<ModelInfo>,
103}
104
105#[derive(Debug, Deserialize)]
106struct ModelInfo {
107 id: String,
108}
109
110#[derive(Clone, Debug)]
111pub struct Client {
113 base_url: String,
114 client: reqwest::Client,
115 headers: HeaderMap,
116}
117
118impl Client {
119 pub fn new(api_key: &str) -> Result<Self, MiraError> {
121 let mut headers = HeaderMap::new();
122 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
123 headers.insert(
124 AUTHORIZATION,
125 HeaderValue::from_str(&format!("Bearer {api_key}"))
126 .map_err(|_| MiraError::InvalidApiKey)?,
127 );
128 headers.insert(
129 reqwest::header::ACCEPT,
130 HeaderValue::from_static("application/json"),
131 );
132 headers.insert(
133 reqwest::header::USER_AGENT,
134 HeaderValue::from_static("rig-client/1.0"),
135 );
136
137 Ok(Self {
138 base_url: MIRA_API_BASE_URL.to_string(),
139 client: reqwest::Client::builder()
140 .build()
141 .expect("Failed to build HTTP client"),
142 headers,
143 })
144 }
145
146 pub fn new_with_base_url(
148 api_key: &str,
149 base_url: impl Into<String>,
150 ) -> Result<Self, MiraError> {
151 let mut client = Self::new(api_key)?;
152 client.base_url = base_url.into();
153 Ok(client)
154 }
155
156 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
158 let url = format!("{}/v1/models", self.base_url);
159
160 let response = self
161 .client
162 .get(&url)
163 .headers(self.headers.clone())
164 .send()
165 .await?;
166
167 let status = response.status();
168
169 if !status.is_success() {
170 let _error_text = response.text().await.unwrap_or_default();
172 tracing::error!("Error response: {}", _error_text);
173 return Err(MiraError::ApiError(status.as_u16()));
174 }
175
176 let response_text = response.text().await?;
177
178 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
179 tracing::error!("Failed to parse response: {}", e);
180 MiraError::JsonError(e)
181 })?;
182
183 Ok(models.data.into_iter().map(|model| model.id).collect())
184 }
185}
186
187impl ProviderClient for Client {
188 fn from_env() -> Self {
191 let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
192 Self::new(&api_key).expect("Could not create Mira Client")
193 }
194}
195
196impl CompletionClient for Client {
197 type CompletionModel = CompletionModel;
198 fn completion_model(&self, model: &str) -> CompletionModel {
200 CompletionModel::new(self.to_owned(), model)
201 }
202}
203
204impl_conversion_traits!(
205 AsEmbeddings,
206 AsTranscription,
207 AsImageGeneration,
208 AsAudioGeneration for Client
209);
210
211#[derive(Clone)]
212pub struct CompletionModel {
213 client: Client,
214 pub model: String,
216}
217
218impl CompletionModel {
219 pub fn new(client: Client, model: &str) -> Self {
220 Self {
221 client,
222 model: model.to_string(),
223 }
224 }
225
226 fn create_completion_request(
227 &self,
228 completion_request: CompletionRequest,
229 ) -> Result<Value, CompletionError> {
230 let mut messages = Vec::new();
231
232 if let Some(preamble) = &completion_request.preamble {
234 messages.push(serde_json::json!({
235 "role": "user",
236 "content": preamble.to_string()
237 }));
238 }
239
240 if let Some(Message::User { content }) = completion_request.normalized_documents() {
242 let text = content
243 .into_iter()
244 .filter_map(|doc| match doc {
245 UserContent::Document(doc) => Some(doc.data),
246 UserContent::Text(text) => Some(text.text),
247
248 _ => None,
250 })
251 .collect::<Vec<_>>()
252 .join("\n");
253
254 messages.push(serde_json::json!({
255 "role": "user",
256 "content": text
257 }));
258 }
259
260 for msg in completion_request.chat_history {
262 let (role, content) = match msg {
263 Message::User { content } => {
264 let text = content
265 .iter()
266 .map(|c| match c {
267 UserContent::Text(text) => &text.text,
268 _ => "",
269 })
270 .collect::<Vec<_>>()
271 .join("\n");
272 ("user", text)
273 }
274 Message::Assistant { content } => {
275 let text = content
276 .iter()
277 .map(|c| match c {
278 AssistantContent::Text(text) => &text.text,
279 _ => "",
280 })
281 .collect::<Vec<_>>()
282 .join("\n");
283 ("assistant", text)
284 }
285 };
286 messages.push(serde_json::json!({
287 "role": role,
288 "content": content
289 }));
290 }
291
292 let request = serde_json::json!({
293 "model": self.model,
294 "messages": messages,
295 "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
296 "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
297 "stream": false
298 });
299
300 Ok(request)
301 }
302}
303
304impl completion::CompletionModel for CompletionModel {
305 type Response = CompletionResponse;
306 type StreamingResponse = openai::StreamingCompletionResponse;
307
308 #[cfg_attr(feature = "worker", worker::send)]
309 async fn completion(
310 &self,
311 completion_request: CompletionRequest,
312 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
313 if !completion_request.tools.is_empty() {
314 tracing::warn!(target: "rig",
315 "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
316 completion_request.tools.len()
317 );
318 }
319
320 let mira_request = self.create_completion_request(completion_request)?;
321
322 let response = self
323 .client
324 .client
325 .post(format!("{}/v1/chat/completions", self.client.base_url))
326 .headers(self.client.headers.clone())
327 .json(&mira_request)
328 .send()
329 .await
330 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
331
332 if !response.status().is_success() {
333 let status = response.status().as_u16();
334 let error_text = response.text().await.unwrap_or_default();
335 return Err(CompletionError::ProviderError(format!(
336 "API error: {status} - {error_text}"
337 )));
338 }
339
340 let response: CompletionResponse = response
341 .json()
342 .await
343 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
344
345 response.try_into()
346 }
347
348 #[cfg_attr(feature = "worker", worker::send)]
349 async fn stream(
350 &self,
351 completion_request: CompletionRequest,
352 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
353 let mut request = self.create_completion_request(completion_request)?;
354
355 request = merge(request, json!({"stream": true}));
356
357 let builder = self
358 .client
359 .client
360 .post(format!("{}/v1/chat/completions", self.client.base_url))
361 .headers(self.client.headers.clone())
362 .json(&request);
363
364 send_compatible_streaming_request(builder).await
365 }
366}
367
368impl From<ApiErrorResponse> for CompletionError {
369 fn from(err: ApiErrorResponse) -> Self {
370 CompletionError::ProviderError(err.message)
371 }
372}
373
374impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
375 type Error = CompletionError;
376
377 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
378 let content = match &response {
379 CompletionResponse::Structured { choices, .. } => {
380 let choice = choices.first().ok_or_else(|| {
381 CompletionError::ResponseError("Response contained no choices".to_owned())
382 })?;
383
384 let message = message::Message::try_from(choice.message.clone())?;
386
387 match message {
388 Message::Assistant { content } => {
389 if content.is_empty() {
390 return Err(CompletionError::ResponseError(
391 "Response contained empty content".to_owned(),
392 ));
393 }
394
395 for c in content.iter() {
397 if !matches!(c, AssistantContent::Text(_)) {
398 tracing::warn!(target: "rig",
399 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
400 );
401 }
402 }
403
404 content.iter().map(|c| {
405 match c {
406 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
407 other => Err(CompletionError::ResponseError(
408 format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
409 ))
410 }
411 }).collect::<Result<Vec<_>, _>>()?
412 }
413 Message::User { .. } => {
414 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
415 return Err(CompletionError::ResponseError(
416 "Received user message in response where assistant message was expected".to_owned()
417 ));
418 }
419 }
420 }
421 CompletionResponse::Simple(text) => {
422 vec![completion::AssistantContent::text(text)]
423 }
424 };
425
426 let choice = OneOrMany::many(content).map_err(|_| {
427 CompletionError::ResponseError(
428 "Response contained no message or tool call (empty)".to_owned(),
429 )
430 })?;
431
432 Ok(completion::CompletionResponse {
433 choice,
434 raw_response: response,
435 })
436 }
437}
438
439#[derive(Clone, Debug, Deserialize)]
440pub struct Usage {
441 pub prompt_tokens: usize,
442 pub total_tokens: usize,
443}
444
445impl std::fmt::Display for Usage {
446 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447 write!(
448 f,
449 "Prompt tokens: {} Total tokens: {}",
450 self.prompt_tokens, self.total_tokens
451 )
452 }
453}
454
455impl From<Message> for serde_json::Value {
456 fn from(msg: Message) -> Self {
457 match msg {
458 Message::User { content } => {
459 let text = content
460 .iter()
461 .map(|c| match c {
462 UserContent::Text(text) => &text.text,
463 _ => "",
464 })
465 .collect::<Vec<_>>()
466 .join("\n");
467 serde_json::json!({
468 "role": "user",
469 "content": text
470 })
471 }
472 Message::Assistant { content } => {
473 let text = content
474 .iter()
475 .map(|c| match c {
476 AssistantContent::Text(text) => &text.text,
477 _ => "",
478 })
479 .collect::<Vec<_>>()
480 .join("\n");
481 serde_json::json!({
482 "role": "assistant",
483 "content": text
484 })
485 }
486 }
487 }
488}
489
490impl TryFrom<serde_json::Value> for Message {
491 type Error = CompletionError;
492
493 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
494 let role = value["role"].as_str().ok_or_else(|| {
495 CompletionError::ResponseError("Message missing role field".to_owned())
496 })?;
497
498 let content = match value.get("content") {
500 Some(content) => match content {
501 serde_json::Value::String(s) => s.clone(),
502 serde_json::Value::Array(arr) => arr
503 .iter()
504 .filter_map(|c| {
505 c.get("text")
506 .and_then(|t| t.as_str())
507 .map(|text| text.to_string())
508 })
509 .collect::<Vec<_>>()
510 .join("\n"),
511 _ => {
512 return Err(CompletionError::ResponseError(
513 "Message content must be string or array".to_owned(),
514 ))
515 }
516 },
517 None => {
518 return Err(CompletionError::ResponseError(
519 "Message missing content field".to_owned(),
520 ))
521 }
522 };
523
524 match role {
525 "user" => Ok(Message::User {
526 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
527 }),
528 "assistant" => Ok(Message::Assistant {
529 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
530 }),
531 _ => Err(CompletionError::ResponseError(format!(
532 "Unsupported message role: {role}"
533 ))),
534 }
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = json!({
            "role": "assistant",
            "content": "Hello there, how may I assist you today?"
        });

        let user_message_json = json!({
            "role": "user",
            "content": "What can you help me with?"
        });

        let assistant_message_array_json = json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello there, how may I assist you today?"
            }]
        });

        let assistant_message = Message::try_from(assistant_message_json).unwrap();
        let user_message = Message::try_from(user_message_json).unwrap();
        let assistant_message_array = Message::try_from(assistant_message_array_json).unwrap();

        match assistant_message {
            Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text(message::Text {
                        text: "What can you help me with?".to_string()
                    })
                );
            }
            _ => panic!("Expected user message"),
        }

        match assistant_message_array {
            Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }
    }

    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let mira_value: serde_json::Value = original_message.clone().try_into().unwrap();

        let converted_message: Message = mira_value.try_into().unwrap();

        let final_message: message::Message = converted_message.try_into().unwrap();

        assert_eq!(original_message, final_message);
    }

    #[test]
    fn test_completion_response_conversion() {
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![ChatChoice {
                message: RawMessage {
                    role: "assistant".to_string(),
                    content: "Test response".to_string(),
                },
                finish_reason: Some("stop".to_string()),
                index: Some(0),
            }],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }

    /// A bare-string body is surfaced as a single text item.
    #[test]
    fn test_simple_completion_response_conversion() {
        let completion_response: completion::CompletionResponse<CompletionResponse> =
            CompletionResponse::Simple("Plain reply".to_string())
                .try_into()
                .unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Plain reply")
        );
    }

    /// A structured response without choices must be rejected, not panic.
    #[test]
    fn test_completion_response_without_choices_is_error() {
        let mira_response = CompletionResponse::Structured {
            id: "resp_456".to_string(),
            object: "chat.completion".to_string(),
            created: 0,
            model: "deepseek-r1".to_string(),
            choices: vec![],
            usage: None,
        };

        let result: Result<completion::CompletionResponse<CompletionResponse>, _> =
            mira_response.try_into();
        assert!(matches!(result, Err(CompletionError::ResponseError(_))));
    }

    /// Roles other than user/assistant are reported as response errors.
    #[test]
    fn test_raw_message_with_unsupported_role_is_error() {
        let raw = RawMessage {
            role: "system".to_string(),
            content: "Be terse".to_string(),
        };

        assert!(matches!(
            message::Message::try_from(raw),
            Err(CompletionError::ResponseError(_))
        ));
    }
}