1use crate::client::{CompletionClient, ProviderClient};
11use crate::json_utils::merge;
12use crate::providers::openai;
13use crate::providers::openai::send_compatible_streaming_request;
14use crate::streaming::StreamingCompletionResponse;
15use crate::{
16 OneOrMany,
17 completion::{self, CompletionError, CompletionRequest},
18 impl_conversion_traits,
19 message::{self, AssistantContent, Message, UserContent},
20};
21use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
22use serde::Deserialize;
23use serde_json::{Value, json};
24use std::string::FromUtf8Error;
25use thiserror::Error;
26use tracing;
27
28#[derive(Debug, Error)]
29pub enum MiraError {
30 #[error("Invalid API key")]
31 InvalidApiKey,
32 #[error("API error: {0}")]
33 ApiError(u16),
34 #[error("Request error: {0}")]
35 RequestError(#[from] reqwest::Error),
36 #[error("UTF-8 error: {0}")]
37 Utf8Error(#[from] FromUtf8Error),
38 #[error("JSON error: {0}")]
39 JsonError(#[from] serde_json::Error),
40}
41
42#[derive(Debug, Deserialize)]
43struct ApiErrorResponse {
44 message: String,
45}
46
47#[derive(Debug, Deserialize, Clone)]
48pub struct RawMessage {
49 pub role: String,
50 pub content: String,
51}
52
/// Base URL used by `Client::new`; `Client::new_with_base_url` overrides it.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
54
55impl TryFrom<RawMessage> for message::Message {
56 type Error = CompletionError;
57
58 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
59 match raw.role.as_str() {
60 "user" => Ok(message::Message::User {
61 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
62 }),
63 "assistant" => Ok(message::Message::Assistant {
64 id: None,
65 content: OneOrMany::one(AssistantContent::Text(message::Text {
66 text: raw.content,
67 })),
68 }),
69 _ => Err(CompletionError::ResponseError(format!(
70 "Unsupported message role: {}",
71 raw.role
72 ))),
73 }
74 }
75}
76
77#[derive(Debug, Deserialize)]
78#[serde(untagged)]
79pub enum CompletionResponse {
80 Structured {
81 id: String,
82 object: String,
83 created: u64,
84 model: String,
85 choices: Vec<ChatChoice>,
86 #[serde(skip_serializing_if = "Option::is_none")]
87 usage: Option<Usage>,
88 },
89 Simple(String),
90}
91
92#[derive(Debug, Deserialize)]
93pub struct ChatChoice {
94 pub message: RawMessage,
95 #[serde(default)]
96 pub finish_reason: Option<String>,
97 #[serde(default)]
98 pub index: Option<usize>,
99}
100
101#[derive(Debug, Deserialize)]
102struct ModelsResponse {
103 data: Vec<ModelInfo>,
104}
105
106#[derive(Debug, Deserialize)]
107struct ModelInfo {
108 id: String,
109}
110
111#[derive(Clone)]
112pub struct Client {
114 base_url: String,
115 http_client: reqwest::Client,
116 api_key: String,
117 headers: HeaderMap,
118}
119
120impl std::fmt::Debug for Client {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("Client")
123 .field("base_url", &self.base_url)
124 .field("http_client", &self.http_client)
125 .field("api_key", &"<REDACTED>")
126 .field("headers", &self.headers)
127 .finish()
128 }
129}
130
131impl Client {
132 pub fn new(api_key: &str) -> Result<Self, MiraError> {
134 let mut headers = HeaderMap::new();
135 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
136 headers.insert(
137 reqwest::header::ACCEPT,
138 HeaderValue::from_static("application/json"),
139 );
140 headers.insert(
141 reqwest::header::USER_AGENT,
142 HeaderValue::from_static("rig-client/1.0"),
143 );
144
145 Ok(Self {
146 base_url: MIRA_API_BASE_URL.to_string(),
147 api_key: api_key.to_string(),
148 http_client: reqwest::Client::builder()
149 .build()
150 .expect("Failed to build HTTP client"),
151 headers,
152 })
153 }
154
155 pub fn new_with_base_url(
157 api_key: &str,
158 base_url: impl Into<String>,
159 ) -> Result<Self, MiraError> {
160 let mut client = Self::new(api_key)?;
161 client.base_url = base_url.into();
162 Ok(client)
163 }
164
165 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
168 self.http_client = client;
169
170 self
171 }
172
173 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
175 let url = format!("{}/v1/models", self.base_url);
176
177 let response = self
178 .http_client
179 .get(&url)
180 .bearer_auth(&self.api_key)
181 .headers(self.headers.clone())
182 .send()
183 .await?;
184
185 let status = response.status();
186
187 if !status.is_success() {
188 let _error_text = response.text().await.unwrap_or_default();
190 tracing::error!("Error response: {}", _error_text);
191 return Err(MiraError::ApiError(status.as_u16()));
192 }
193
194 let response_text = response.text().await?;
195
196 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
197 tracing::error!("Failed to parse response: {}", e);
198 MiraError::JsonError(e)
199 })?;
200
201 Ok(models.data.into_iter().map(|model| model.id).collect())
202 }
203}
204
205impl ProviderClient for Client {
206 fn from_env() -> Self {
209 let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
210 Self::new(&api_key).expect("Could not create Mira Client")
211 }
212
213 fn from_val(input: crate::client::ProviderValue) -> Self {
214 let crate::client::ProviderValue::Simple(api_key) = input else {
215 panic!("Incorrect provider value type")
216 };
217 Self::new(&api_key).unwrap()
218 }
219}
220
221impl CompletionClient for Client {
222 type CompletionModel = CompletionModel;
223 fn completion_model(&self, model: &str) -> CompletionModel {
225 CompletionModel::new(self.to_owned(), model)
226 }
227}
228
// Generates the listed `As*` conversion-trait impls for `Client`. Only completions are
// implemented in this module. NOTE(review): presumably these impls report the
// capabilities as unavailable; confirm against the `impl_conversion_traits!` definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
235
236#[derive(Clone)]
237pub struct CompletionModel {
238 client: Client,
239 pub model: String,
241}
242
243impl CompletionModel {
244 pub fn new(client: Client, model: &str) -> Self {
245 Self {
246 client,
247 model: model.to_string(),
248 }
249 }
250
251 fn create_completion_request(
252 &self,
253 completion_request: CompletionRequest,
254 ) -> Result<Value, CompletionError> {
255 let mut messages = Vec::new();
256
257 if let Some(preamble) = &completion_request.preamble {
259 messages.push(serde_json::json!({
260 "role": "user",
261 "content": preamble.to_string()
262 }));
263 }
264
265 if let Some(Message::User { content }) = completion_request.normalized_documents() {
267 let text = content
268 .into_iter()
269 .filter_map(|doc| match doc {
270 UserContent::Document(doc) => Some(doc.data),
271 UserContent::Text(text) => Some(text.text),
272
273 _ => None,
275 })
276 .collect::<Vec<_>>()
277 .join("\n");
278
279 messages.push(serde_json::json!({
280 "role": "user",
281 "content": text
282 }));
283 }
284
285 for msg in completion_request.chat_history {
287 let (role, content) = match msg {
288 Message::User { content } => {
289 let text = content
290 .iter()
291 .map(|c| match c {
292 UserContent::Text(text) => &text.text,
293 _ => "",
294 })
295 .collect::<Vec<_>>()
296 .join("\n");
297 ("user", text)
298 }
299 Message::Assistant { content, .. } => {
300 let text = content
301 .iter()
302 .map(|c| match c {
303 AssistantContent::Text(text) => &text.text,
304 _ => "",
305 })
306 .collect::<Vec<_>>()
307 .join("\n");
308 ("assistant", text)
309 }
310 };
311 messages.push(serde_json::json!({
312 "role": role,
313 "content": content
314 }));
315 }
316
317 let request = serde_json::json!({
318 "model": self.model,
319 "messages": messages,
320 "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
321 "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
322 "stream": false
323 });
324
325 Ok(request)
326 }
327}
328
329impl completion::CompletionModel for CompletionModel {
330 type Response = CompletionResponse;
331 type StreamingResponse = openai::StreamingCompletionResponse;
332
333 #[cfg_attr(feature = "worker", worker::send)]
334 async fn completion(
335 &self,
336 completion_request: CompletionRequest,
337 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
338 if !completion_request.tools.is_empty() {
339 tracing::warn!(target: "rig",
340 "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
341 completion_request.tools.len()
342 );
343 }
344
345 let mira_request = self.create_completion_request(completion_request)?;
346
347 let response = self
348 .client
349 .http_client
350 .post(format!("{}/v1/chat/completions", self.client.base_url))
351 .bearer_auth(&self.client.api_key)
352 .headers(self.client.headers.clone())
353 .json(&mira_request)
354 .send()
355 .await
356 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
357
358 if !response.status().is_success() {
359 let status = response.status().as_u16();
360 let error_text = response.text().await.unwrap_or_default();
361 return Err(CompletionError::ProviderError(format!(
362 "API error: {status} - {error_text}"
363 )));
364 }
365
366 let response: CompletionResponse = response
367 .json()
368 .await
369 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
370
371 response.try_into()
372 }
373
374 #[cfg_attr(feature = "worker", worker::send)]
375 async fn stream(
376 &self,
377 completion_request: CompletionRequest,
378 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
379 let mut request = self.create_completion_request(completion_request)?;
380
381 request = merge(request, json!({"stream": true}));
382
383 let builder = self
384 .client
385 .http_client
386 .post(format!("{}/v1/chat/completions", self.client.base_url))
387 .headers(self.client.headers.clone())
388 .json(&request);
389
390 send_compatible_streaming_request(builder).await
391 }
392}
393
394impl From<ApiErrorResponse> for CompletionError {
395 fn from(err: ApiErrorResponse) -> Self {
396 CompletionError::ProviderError(err.message)
397 }
398}
399
400impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
401 type Error = CompletionError;
402
403 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
404 let (content, usage) = match &response {
405 CompletionResponse::Structured { choices, usage, .. } => {
406 let choice = choices.first().ok_or_else(|| {
407 CompletionError::ResponseError("Response contained no choices".to_owned())
408 })?;
409
410 let usage = usage
411 .as_ref()
412 .map(|usage| completion::Usage {
413 input_tokens: usage.prompt_tokens as u64,
414 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
415 total_tokens: usage.total_tokens as u64,
416 })
417 .unwrap_or_default();
418
419 let message = message::Message::try_from(choice.message.clone())?;
421
422 let content = match message {
423 Message::Assistant { content, .. } => {
424 if content.is_empty() {
425 return Err(CompletionError::ResponseError(
426 "Response contained empty content".to_owned(),
427 ));
428 }
429
430 for c in content.iter() {
432 if !matches!(c, AssistantContent::Text(_)) {
433 tracing::warn!(target: "rig",
434 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
435 );
436 }
437 }
438
439 content.iter().map(|c| {
440 match c {
441 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
442 other => Err(CompletionError::ResponseError(
443 format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
444 ))
445 }
446 }).collect::<Result<Vec<_>, _>>()?
447 }
448 Message::User { .. } => {
449 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
450 return Err(CompletionError::ResponseError(
451 "Received user message in response where assistant message was expected".to_owned()
452 ));
453 }
454 };
455
456 (content, usage)
457 }
458 CompletionResponse::Simple(text) => (
459 vec![completion::AssistantContent::text(text)],
460 completion::Usage::new(),
461 ),
462 };
463
464 let choice = OneOrMany::many(content).map_err(|_| {
465 CompletionError::ResponseError(
466 "Response contained no message or tool call (empty)".to_owned(),
467 )
468 })?;
469
470 Ok(completion::CompletionResponse {
471 choice,
472 usage,
473 raw_response: response,
474 })
475 }
476}
477
478#[derive(Clone, Debug, Deserialize)]
479pub struct Usage {
480 pub prompt_tokens: usize,
481 pub total_tokens: usize,
482}
483
484impl std::fmt::Display for Usage {
485 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
486 write!(
487 f,
488 "Prompt tokens: {} Total tokens: {}",
489 self.prompt_tokens, self.total_tokens
490 )
491 }
492}
493
494impl From<Message> for serde_json::Value {
495 fn from(msg: Message) -> Self {
496 match msg {
497 Message::User { content } => {
498 let text = content
499 .iter()
500 .map(|c| match c {
501 UserContent::Text(text) => &text.text,
502 _ => "",
503 })
504 .collect::<Vec<_>>()
505 .join("\n");
506 serde_json::json!({
507 "role": "user",
508 "content": text
509 })
510 }
511 Message::Assistant { content, .. } => {
512 let text = content
513 .iter()
514 .map(|c| match c {
515 AssistantContent::Text(text) => &text.text,
516 _ => "",
517 })
518 .collect::<Vec<_>>()
519 .join("\n");
520 serde_json::json!({
521 "role": "assistant",
522 "content": text
523 })
524 }
525 }
526 }
527}
528
529impl TryFrom<serde_json::Value> for Message {
530 type Error = CompletionError;
531
532 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
533 let role = value["role"].as_str().ok_or_else(|| {
534 CompletionError::ResponseError("Message missing role field".to_owned())
535 })?;
536
537 let content = match value.get("content") {
539 Some(content) => match content {
540 serde_json::Value::String(s) => s.clone(),
541 serde_json::Value::Array(arr) => arr
542 .iter()
543 .filter_map(|c| {
544 c.get("text")
545 .and_then(|t| t.as_str())
546 .map(|text| text.to_string())
547 })
548 .collect::<Vec<_>>()
549 .join("\n"),
550 _ => {
551 return Err(CompletionError::ResponseError(
552 "Message content must be string or array".to_owned(),
553 ));
554 }
555 },
556 None => {
557 return Err(CompletionError::ResponseError(
558 "Message missing content field".to_owned(),
559 ));
560 }
561 };
562
563 match role {
564 "user" => Ok(Message::User {
565 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
566 }),
567 "assistant" => Ok(Message::Assistant {
568 id: None,
569 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
570 }),
571 _ => Err(CompletionError::ResponseError(format!(
572 "Unsupported message role: {role}"
573 ))),
574 }
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// Greeting shared by the assistant fixtures.
    const GREETING: &str = "Hello there, how may I assist you today?";

    /// String and array `content` forms parse into the matching rig message variants.
    #[test]
    fn test_deserialize_message() {
        let assistant_message = Message::try_from(json!({
            "role": "assistant",
            "content": GREETING
        }))
        .unwrap();
        let user_message = Message::try_from(json!({
            "role": "user",
            "content": "What can you help me with?"
        }))
        .unwrap();
        let assistant_message_array = Message::try_from(json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": GREETING
            }]
        }))
        .unwrap();

        let expected_assistant = AssistantContent::Text(message::Text {
            text: GREETING.to_string(),
        });

        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.first(), expected_assistant);

        let Message::User { content } = user_message else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text(message::Text {
                text: "What can you help me with?".to_string()
            })
        );

        let Message::Assistant { content, .. } = assistant_message_array else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.first(), expected_assistant);
    }

    /// A user text message survives a round trip through `serde_json::Value`.
    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let mira_value = serde_json::Value::from(original_message.clone());
        let converted_message = Message::try_from(mira_value).unwrap();

        assert_eq!(original_message, converted_message);
    }

    /// A structured response's first assistant choice becomes the completion's text.
    #[test]
    fn test_completion_response_conversion() {
        let choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };

        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![choice],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response =
            completion::CompletionResponse::<CompletionResponse>::try_from(mira_response)
                .unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
}