1use crate::client::{
11 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
12 ProviderClient,
13};
14use crate::http_client::{self, HttpClientExt};
15use crate::message::{Document, DocumentSourceKind};
16use crate::providers::openai;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 message::{self, AssistantContent, Message, UserContent},
23};
24use serde::{Deserialize, Serialize};
25use std::string::FromUtf8Error;
26use thiserror::Error;
27use tracing::{self, Instrument, info_span};
28
/// Marker type implementing the Mira provider extension; carries no state.
#[derive(Debug, Default, Clone, Copy)]
pub struct MiraExt;

/// Marker type used by the generic client machinery to build a Mira [`Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct MiraBuilder;

// Mira authenticates with a standard `Authorization: Bearer <key>` header.
type MiraApiKey = BearerAuth;
35
/// Ties the Mira extension to its builder type and the endpoint the client
/// layer probes to verify credentials.
impl Provider for MiraExt {
    type Builder = MiraBuilder;

    // Endpoint used to check that an API key is valid.
    const VERIFY_PATH: &'static str = "/user-credits";
}
41
// Mira only supports chat completions; every other capability slot is
// explicitly disabled via `Nothing`.
impl<H> Capabilities<H> for MiraExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;

    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for MiraExt {}
56
impl ProviderBuilder for MiraBuilder {
    // The extension is stateless, so it is the same type for every HTTP
    // transport `H`.
    type Extension<H>
        = MiraExt
    where
        H: HttpClientExt;
    type ApiKey = MiraApiKey;

    const BASE_URL: &'static str = MIRA_API_BASE_URL;

    /// Produces the (stateless) Mira extension; this can never fail.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(MiraExt)
    }
}
75
/// Mira client over a pluggable HTTP transport (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<MiraExt, H>;
/// Builder for [`Client`].
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MiraBuilder, MiraApiKey, H>;
78
/// Errors surfaced by Mira-specific operations (e.g. [`Client::list_models`])
/// that fall outside the common [`CompletionError`] path.
#[derive(Debug, Error)]
pub enum MiraError {
    /// The supplied API key was rejected.
    #[error("Invalid API key")]
    InvalidApiKey,
    /// Non-success HTTP status code returned by the API.
    #[error("API error: {0}")]
    ApiError(u16),
    /// Transport-level failure.
    #[error("Request error: {0}")]
    RequestError(#[from] http_client::Error),
    /// Response body was not valid UTF-8.
    #[error("UTF-8 error: {0}")]
    Utf8Error(#[from] FromUtf8Error),
    /// Response body could not be deserialized.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
}
92
/// Error payload shape returned by the Mira API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Wire-format chat message: a plain role string ("system" | "user" |
/// "assistant") plus flat text content.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct RawMessage {
    pub role: String,
    pub content: String,
}

const MIRA_API_BASE_URL: &str = "https://api.mira.network";
105
106impl TryFrom<RawMessage> for message::Message {
107 type Error = CompletionError;
108
109 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
110 match raw.role.as_str() {
111 "system" => Ok(message::Message::System {
112 content: raw.content,
113 }),
114 "user" => Ok(message::Message::User {
115 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
116 }),
117 "assistant" => Ok(message::Message::Assistant {
118 id: None,
119 content: OneOrMany::one(AssistantContent::Text(message::Text {
120 text: raw.content,
121 })),
122 }),
123 _ => Err(CompletionError::ResponseError(format!(
124 "Unsupported message role: {}",
125 raw.role
126 ))),
127 }
128 }
129}
130
/// A chat-completion response from Mira.
///
/// `#[serde(untagged)]` means deserialization first tries the structured,
/// OpenAI-style object and falls back to a bare string.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum CompletionResponse {
    /// OpenAI-compatible response object.
    Structured {
        id: String,
        object: String,
        created: u64,
        model: String,
        choices: Vec<ChatChoice>,
        #[serde(skip_serializing_if = "Option::is_none")]
        usage: Option<Usage>,
    },
    /// Plain-text fallback when the body is just a string.
    Simple(String),
}

/// A single completion choice within a structured response.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatChoice {
    pub message: RawMessage,
    #[serde(default)]
    pub finish_reason: Option<String>,
    #[serde(default)]
    pub index: Option<usize>,
}
154
/// Response shape of `GET /v1/models`.
#[derive(Debug, Deserialize, Serialize)]
struct ModelsResponse {
    data: Vec<ModelInfo>,
}

/// A single model entry; only the identifier is used.
#[derive(Debug, Deserialize, Serialize)]
struct ModelInfo {
    id: String,
}
164
impl<T> Client<T>
where
    T: HttpClientExt + 'static,
{
    /// Fetches the model IDs available on the Mira API (`GET /v1/models`).
    ///
    /// # Errors
    /// Returns [`MiraError::ApiError`] on a non-success status,
    /// [`MiraError::RequestError`] on transport failures, and
    /// [`MiraError::JsonError`] if the body does not match [`ModelsResponse`].
    pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
        let req = self.get("/v1/models").and_then(|req| {
            req.body(http_client::NoBody)
                .map_err(http_client::Error::Protocol)
        })?;

        let response = self.send(req).await?;

        let status = response.status();

        if !status.is_success() {
            // Best-effort read of the error body, purely for logging; the
            // caller only gets the status code.
            let error_text = http_client::text(response).await.unwrap_or_default();
            tracing::error!("Error response: {}", error_text);
            return Err(MiraError::ApiError(status.as_u16()));
        }

        let response_text = http_client::text(response).await?;

        let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
            tracing::error!("Failed to parse response: {}", e);
            MiraError::JsonError(e)
        })?;

        Ok(models.data.into_iter().map(|model| model.id).collect())
    }
}
197
impl ProviderClient for Client {
    type Input = String;

    /// Builds a client from the `MIRA_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if the variable is unset or client construction fails.
    fn from_env() -> Self {
        let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    /// Builds a client directly from an API key string.
    ///
    /// # Panics
    /// Panics if client construction fails.
    fn from_val(input: Self::Input) -> Self {
        Self::new(&input).unwrap()
    }
}
212
/// Request body for Mira's `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MiraCompletionRequest {
    model: String,
    pub messages: Vec<RawMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    // `true` requests an SSE stream; flipped on by `CompletionModel::stream`.
    pub stream: bool,
}
223
224impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest {
225 type Error = CompletionError;
226
227 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
228 if req.output_schema.is_some() {
229 tracing::warn!("Structured outputs currently not supported for Mira");
230 }
231 let model = req.model.clone().unwrap_or_else(|| model.to_string());
232 let mut messages = Vec::new();
233
234 if let Some(content) = &req.preamble {
235 messages.push(RawMessage {
236 role: "user".to_string(),
237 content: content.to_string(),
238 });
239 }
240
241 if let Some(Message::User { content }) = req.normalized_documents() {
242 let text = content
243 .into_iter()
244 .filter_map(|doc| match doc {
245 UserContent::Document(Document {
246 data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data),
247 ..
248 }) => Some(data),
249 UserContent::Text(text) => Some(text.text),
250
251 _ => None,
253 })
254 .collect::<Vec<_>>()
255 .join("\n");
256
257 messages.push(RawMessage {
258 role: "user".to_string(),
259 content: text,
260 });
261 }
262
263 for msg in req.chat_history {
264 let (role, content) = match msg {
265 Message::System { content } => ("system", content),
266 Message::User { content } => {
267 let text = content
268 .iter()
269 .map(|c| match c {
270 UserContent::Text(text) => &text.text,
271 _ => "",
272 })
273 .collect::<Vec<_>>()
274 .join("\n");
275 ("user", text)
276 }
277 Message::Assistant { content, .. } => {
278 let text = content
279 .iter()
280 .map(|c| match c {
281 AssistantContent::Text(text) => &text.text,
282 _ => "",
283 })
284 .collect::<Vec<_>>()
285 .join("\n");
286 ("assistant", text)
287 }
288 };
289 messages.push(RawMessage {
290 role: role.to_string(),
291 content,
292 });
293 }
294
295 Ok(Self {
296 model: model.to_string(),
297 messages,
298 temperature: req.temperature,
299 max_tokens: req.max_tokens,
300 stream: false,
301 })
302 }
303}
304
/// A Mira chat-completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    // Model identifier, e.g. "deepseek-r1".
    pub model: String,
}

impl<T> CompletionModel<T> {
    /// Creates a completion model that sends requests through `client`
    /// using the given model identifier.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
320
321impl<T> completion::CompletionModel for CompletionModel<T>
322where
323 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
324{
325 type Response = CompletionResponse;
326 type StreamingResponse = openai::StreamingCompletionResponse;
327
328 type Client = Client<T>;
329
330 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
331 Self::new(client.clone(), model)
332 }
333
334 async fn completion(
335 &self,
336 completion_request: CompletionRequest,
337 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
338 let span = if tracing::Span::current().is_disabled() {
339 info_span!(
340 target: "rig::completions",
341 "chat",
342 gen_ai.operation.name = "chat",
343 gen_ai.provider.name = "mira",
344 gen_ai.request.model = self.model,
345 gen_ai.system_instructions = tracing::field::Empty,
346 gen_ai.response.id = tracing::field::Empty,
347 gen_ai.response.model = tracing::field::Empty,
348 gen_ai.usage.output_tokens = tracing::field::Empty,
349 gen_ai.usage.input_tokens = tracing::field::Empty,
350 gen_ai.usage.cached_tokens = tracing::field::Empty,
351 )
352 } else {
353 tracing::Span::current()
354 };
355
356 span.record("gen_ai.system_instructions", &completion_request.preamble);
357
358 if !completion_request.tools.is_empty() {
359 tracing::warn!(target: "rig::completions",
360 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
361 len = completion_request.tools.len()
362 );
363 }
364
365 if completion_request.tool_choice.is_some() {
366 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
367 }
368
369 if completion_request.additional_params.is_some() {
370 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
371 }
372
373 let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
374
375 if tracing::enabled!(tracing::Level::TRACE) {
376 tracing::trace!(target: "rig::completions",
377 "Mira completion request: {}",
378 serde_json::to_string_pretty(&request)?
379 );
380 }
381
382 let body = serde_json::to_vec(&request)?;
383
384 let req = self
385 .client
386 .post("/v1/chat/completions")?
387 .body(body)
388 .map_err(http_client::Error::from)?;
389
390 let async_block = async move {
391 let response = self
392 .client
393 .send::<_, bytes::Bytes>(req)
394 .await
395 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
396
397 let status = response.status();
398 let response_body = response.into_body().into_future().await?.to_vec();
399
400 if !status.is_success() {
401 let status = status.as_u16();
402 let error_text = String::from_utf8_lossy(&response_body).to_string();
403 return Err(CompletionError::ProviderError(format!(
404 "API error: {status} - {error_text}"
405 )));
406 }
407
408 let response: CompletionResponse = serde_json::from_slice(&response_body)?;
409
410 if tracing::enabled!(tracing::Level::TRACE) {
411 tracing::trace!(target: "rig::completions",
412 "Mira completion response: {}",
413 serde_json::to_string_pretty(&response)?
414 );
415 }
416
417 if let CompletionResponse::Structured {
418 id, model, usage, ..
419 } = &response
420 {
421 let span = tracing::Span::current();
422 span.record("gen_ai.response.model_name", model);
423 span.record("gen_ai.response.id", id);
424 if let Some(usage) = usage {
425 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
426 span.record(
427 "gen_ai.usage.output_tokens",
428 usage.total_tokens - usage.prompt_tokens,
429 );
430 }
431 }
432
433 response.try_into()
434 };
435
436 async_block.instrument(span).await
437 }
438
439 async fn stream(
440 &self,
441 completion_request: CompletionRequest,
442 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
443 let span = if tracing::Span::current().is_disabled() {
444 info_span!(
445 target: "rig::completions",
446 "chat_streaming",
447 gen_ai.operation.name = "chat_streaming",
448 gen_ai.provider.name = "mira",
449 gen_ai.request.model = self.model,
450 gen_ai.system_instructions = tracing::field::Empty,
451 gen_ai.response.id = tracing::field::Empty,
452 gen_ai.response.model = tracing::field::Empty,
453 gen_ai.usage.output_tokens = tracing::field::Empty,
454 gen_ai.usage.input_tokens = tracing::field::Empty,
455 gen_ai.usage.cached_tokens = tracing::field::Empty,
456 )
457 } else {
458 tracing::Span::current()
459 };
460
461 span.record("gen_ai.system_instructions", &completion_request.preamble);
462
463 if !completion_request.tools.is_empty() {
464 tracing::warn!(target: "rig::completions",
465 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
466 len = completion_request.tools.len()
467 );
468 }
469
470 if completion_request.tool_choice.is_some() {
471 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
472 }
473
474 if completion_request.additional_params.is_some() {
475 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
476 }
477 let mut request =
478 MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
479 request.stream = true;
480
481 if tracing::enabled!(tracing::Level::TRACE) {
482 tracing::trace!(target: "rig::completions",
483 "Mira completion request: {}",
484 serde_json::to_string_pretty(&request)?
485 );
486 }
487
488 let body = serde_json::to_vec(&request)?;
489
490 let req = self
491 .client
492 .post("/v1/chat/completions")?
493 .body(body)
494 .map_err(http_client::Error::from)?;
495
496 send_compatible_streaming_request(self.client.clone(), req)
497 .instrument(span)
498 .await
499 }
500}
501
impl From<ApiErrorResponse> for CompletionError {
    // Surface the API's own error message as a provider error.
    fn from(err: ApiErrorResponse) -> Self {
        CompletionError::ProviderError(err.message)
    }
}
507
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    /// Converts a raw Mira response into the provider-agnostic
    /// [`completion::CompletionResponse`].
    ///
    /// Only the first choice of a structured response is used; a `Simple`
    /// string becomes a single text part with zero usage.
    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        let (content, usage) = match &response {
            CompletionResponse::Structured { choices, usage, .. } => {
                let choice = choices.first().ok_or_else(|| {
                    CompletionError::ResponseError("Response contained no choices".to_owned())
                })?;

                // Mira reports prompt and total tokens; output tokens are
                // derived by subtraction.
                // NOTE(review): assumes total_tokens >= prompt_tokens — a
                // usize underflow here would panic in debug builds; confirm
                // the API guarantees this.
                let usage = usage
                    .as_ref()
                    .map(|usage| completion::Usage {
                        input_tokens: usage.prompt_tokens as u64,
                        output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
                        total_tokens: usage.total_tokens as u64,
                        cached_input_tokens: 0,
                    })
                    .unwrap_or_default();

                let message = message::Message::try_from(choice.message.clone())?;

                let content = match message {
                    Message::Assistant { content, .. } => {
                        if content.is_empty() {
                            return Err(CompletionError::ResponseError(
                                "Response contained empty content".to_owned(),
                            ));
                        }

                        // Warn on every unsupported part before erroring out
                        // below, so all offenders are logged.
                        for c in content.iter() {
                            if !matches!(c, AssistantContent::Text(_)) {
                                tracing::warn!(target: "rig",
                                    "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
                                );
                            }
                        }

                        content.iter().map(|c| {
                            match c {
                                AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
                                other => Err(CompletionError::ResponseError(
                                    format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
                                ))
                            }
                        }).collect::<Result<Vec<_>, _>>()?
                    }
                    Message::User { .. } => {
                        tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
                        return Err(CompletionError::ResponseError(
                            "Received user message in response where assistant message was expected".to_owned()
                        ));
                    }
                    Message::System { .. } => {
                        tracing::warn!(target: "rig", "Received system message in response where assistant message was expected");
                        return Err(CompletionError::ResponseError(
                            "Received system message in response where assistant message was expected".to_owned(),
                        ));
                    }
                };

                (content, usage)
            }
            CompletionResponse::Simple(text) => (
                vec![completion::AssistantContent::text(text)],
                completion::Usage::new(),
            ),
        };

        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
            message_id: None,
        })
    }
}
593
/// Token usage as reported by Mira. Output tokens are not reported directly
/// and are derived elsewhere as `total_tokens - prompt_tokens`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}

impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Prompt tokens: {} Total tokens: {}",
            self.prompt_tokens, self.total_tokens
        )
    }
}
609
610impl From<Message> for serde_json::Value {
611 fn from(msg: Message) -> Self {
612 match msg {
613 Message::System { content } => serde_json::json!({
614 "role": "system",
615 "content": content
616 }),
617 Message::User { content } => {
618 let text = content
619 .iter()
620 .map(|c| match c {
621 UserContent::Text(text) => &text.text,
622 _ => "",
623 })
624 .collect::<Vec<_>>()
625 .join("\n");
626 serde_json::json!({
627 "role": "user",
628 "content": text
629 })
630 }
631 Message::Assistant { content, .. } => {
632 let text = content
633 .iter()
634 .map(|c| match c {
635 AssistantContent::Text(text) => &text.text,
636 _ => "",
637 })
638 .collect::<Vec<_>>()
639 .join("\n");
640 serde_json::json!({
641 "role": "assistant",
642 "content": text
643 })
644 }
645 }
646 }
647}
648
649impl TryFrom<serde_json::Value> for Message {
650 type Error = CompletionError;
651
652 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
653 let role = value["role"].as_str().ok_or_else(|| {
654 CompletionError::ResponseError("Message missing role field".to_owned())
655 })?;
656
657 let content = match value.get("content") {
659 Some(content) => match content {
660 serde_json::Value::String(s) => s.clone(),
661 serde_json::Value::Array(arr) => arr
662 .iter()
663 .filter_map(|c| {
664 c.get("text")
665 .and_then(|t| t.as_str())
666 .map(|text| text.to_string())
667 })
668 .collect::<Vec<_>>()
669 .join("\n"),
670 _ => {
671 return Err(CompletionError::ResponseError(
672 "Message content must be string or array".to_owned(),
673 ));
674 }
675 },
676 None => {
677 return Err(CompletionError::ResponseError(
678 "Message missing content field".to_owned(),
679 ));
680 }
681 };
682
683 match role {
684 "system" => Ok(Message::System { content }),
685 "user" => Ok(Message::User {
686 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
687 }),
688 "assistant" => Ok(Message::Assistant {
689 id: None,
690 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
691 }),
692 _ => Err(CompletionError::ResponseError(format!(
693 "Unsupported message role: {role}"
694 ))),
695 }
696 }
697}
698
699#[cfg(test)]
700mod tests {
701 use super::*;
702 use crate::message::UserContent;
703 use serde_json::json;
704
    /// Round-trips JSON payloads — both string and array content forms —
    /// through `Message::try_from(serde_json::Value)`.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = json!({
            "role": "assistant",
            "content": "Hello there, how may I assist you today?"
        });

        let user_message_json = json!({
            "role": "user",
            "content": "What can you help me with?"
        });

        // Array-of-parts content must be flattened to the same text.
        let assistant_message_array_json = json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello there, how may I assist you today?"
            }]
        });

        let assistant_message = Message::try_from(assistant_message_json).unwrap();
        let user_message = Message::try_from(user_message_json).unwrap();
        let assistant_message_array = Message::try_from(assistant_message_array_json).unwrap();

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text(message::Text {
                        text: "What can you help me with?".to_string()
                    })
                );
            }
            _ => panic!("Expected user message"),
        }

        match assistant_message_array {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }
    }
769
    /// Converts a user `Message` to the JSON wire shape and back, checking
    /// the round trip is lossless for plain text.
    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let mira_value: serde_json::Value = original_message.clone().into();

        let converted_message: Message = mira_value.try_into().unwrap();

        assert_eq!(original_message, converted_message);
    }
785
    /// A structured Mira response maps onto the generic completion response
    /// with the assistant text preserved as the first choice.
    #[test]
    fn test_completion_response_conversion() {
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![ChatChoice {
                message: RawMessage {
                    role: "assistant".to_string(),
                    content: "Test response".to_string(),
                },
                finish_reason: Some("stop".to_string()),
                index: Some(0),
            }],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
    // Smoke test: both construction paths accept an arbitrary key without
    // performing any network I/O.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::mira::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::mira::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
824}