1use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
11use crate::http_client::{self, HttpClientExt};
12use crate::json_utils::merge;
13use crate::message::{Document, DocumentSourceKind};
14use crate::providers::openai;
15use crate::providers::openai::send_compatible_streaming_request;
16use crate::streaming::StreamingCompletionResponse;
17use crate::{
18 OneOrMany,
19 completion::{self, CompletionError, CompletionRequest},
20 impl_conversion_traits,
21 message::{self, AssistantContent, Message, UserContent},
22};
23use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26use std::string::FromUtf8Error;
27use thiserror::Error;
28use tracing::{self, Instrument, info_span};
29
30#[derive(Debug, Error)]
31pub enum MiraError {
32 #[error("Invalid API key")]
33 InvalidApiKey,
34 #[error("API error: {0}")]
35 ApiError(u16),
36 #[error("Request error: {0}")]
37 RequestError(#[from] http_client::Error),
38 #[error("UTF-8 error: {0}")]
39 Utf8Error(#[from] FromUtf8Error),
40 #[error("JSON error: {0}")]
41 JsonError(#[from] serde_json::Error),
42}
43
44#[derive(Debug, Deserialize)]
45struct ApiErrorResponse {
46 message: String,
47}
48
49#[derive(Debug, Deserialize, Clone, Serialize)]
50pub struct RawMessage {
51 pub role: String,
52 pub content: String,
53}
54
/// Default Mira API endpoint; `ClientBuilder::new` starts from this base URL.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
56
57impl TryFrom<RawMessage> for message::Message {
58 type Error = CompletionError;
59
60 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
61 match raw.role.as_str() {
62 "user" => Ok(message::Message::User {
63 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
64 }),
65 "assistant" => Ok(message::Message::Assistant {
66 id: None,
67 content: OneOrMany::one(AssistantContent::Text(message::Text {
68 text: raw.content,
69 })),
70 }),
71 _ => Err(CompletionError::ResponseError(format!(
72 "Unsupported message role: {}",
73 raw.role
74 ))),
75 }
76 }
77}
78
79#[derive(Debug, Deserialize, Serialize)]
80#[serde(untagged)]
81pub enum CompletionResponse {
82 Structured {
83 id: String,
84 object: String,
85 created: u64,
86 model: String,
87 choices: Vec<ChatChoice>,
88 #[serde(skip_serializing_if = "Option::is_none")]
89 usage: Option<Usage>,
90 },
91 Simple(String),
92}
93
94#[derive(Debug, Deserialize, Serialize)]
95pub struct ChatChoice {
96 pub message: RawMessage,
97 #[serde(default)]
98 pub finish_reason: Option<String>,
99 #[serde(default)]
100 pub index: Option<usize>,
101}
102
103#[derive(Debug, Deserialize, Serialize)]
104struct ModelsResponse {
105 data: Vec<ModelInfo>,
106}
107
108#[derive(Debug, Deserialize, Serialize)]
109struct ModelInfo {
110 id: String,
111}
112
113pub struct ClientBuilder<'a, T = reqwest::Client> {
114 api_key: &'a str,
115 base_url: &'a str,
116 http_client: T,
117}
118
119impl<'a, T> ClientBuilder<'a, T>
120where
121 T: Default,
122{
123 pub fn new(api_key: &'a str) -> Self {
124 Self {
125 api_key,
126 base_url: MIRA_API_BASE_URL,
127 http_client: Default::default(),
128 }
129 }
130}
131
132impl<'a, T> ClientBuilder<'a, T> {
133 pub fn base_url(mut self, base_url: &'a str) -> Self {
134 self.base_url = base_url;
135 self
136 }
137
138 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
139 ClientBuilder {
140 api_key: self.api_key,
141 base_url: self.base_url,
142 http_client,
143 }
144 }
145
146 pub fn build(self) -> Client<T> {
147 let mut headers = HeaderMap::new();
148 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
149 headers.insert(
150 reqwest::header::ACCEPT,
151 HeaderValue::from_static("application/json"),
152 );
153 headers.insert(
154 reqwest::header::USER_AGENT,
155 HeaderValue::from_static("rig-client/1.0"),
156 );
157
158 Client {
159 base_url: self.base_url.to_string(),
160 http_client: self.http_client,
161 api_key: self.api_key.to_string(),
162 headers,
163 }
164 }
165}
166
167#[derive(Clone)]
168pub struct Client<T = reqwest::Client> {
170 base_url: String,
171 http_client: T,
172 api_key: String,
173 headers: HeaderMap,
174}
175
176impl<T> std::fmt::Debug for Client<T>
177where
178 T: std::fmt::Debug,
179{
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 f.debug_struct("Client")
182 .field("base_url", &self.base_url)
183 .field("http_client", &self.http_client)
184 .field("api_key", &"<REDACTED>")
185 .field("headers", &self.headers)
186 .finish()
187 }
188}
189
190impl<T> Client<T>
191where
192 T: Default,
193{
194 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
205 ClientBuilder::new(api_key)
206 }
207
208 pub fn new(api_key: &str) -> Self {
213 Self::builder(api_key).build()
214 }
215}
216
217impl<T> Client<T>
218where
219 T: HttpClientExt,
220{
221 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
223 let req = self.get("/v1/models").and_then(|req| {
224 req.body(http_client::NoBody)
225 .map_err(http_client::Error::Protocol)
226 })?;
227
228 let response = self.http_client.send(req).await?;
229
230 let status = response.status();
231
232 if !status.is_success() {
233 let error_text = http_client::text(response).await.unwrap_or_default();
235 tracing::error!("Error response: {}", error_text);
236 return Err(MiraError::ApiError(status.as_u16()));
237 }
238
239 let response_text = http_client::text(response).await?;
240
241 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
242 tracing::error!("Failed to parse response: {}", e);
243 MiraError::JsonError(e)
244 })?;
245
246 Ok(models.data.into_iter().map(|model| model.id).collect())
247 }
248
249 fn req(
250 &self,
251 method: http_client::Method,
252 path: &str,
253 ) -> http_client::Result<http_client::Builder> {
254 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
255
256 let mut req = http_client::Builder::new().method(method).uri(url);
257
258 if let Some(hs) = req.headers_mut() {
259 *hs = self.headers.clone();
260 }
261
262 http_client::with_bearer_auth(req, &self.api_key)
263 }
264
265 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
266 self.req(http_client::Method::POST, path)
267 }
268}
269
270impl Client<reqwest::Client> {
271 fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
272 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
273
274 self.http_client
275 .post(url)
276 .bearer_auth(&self.api_key)
277 .headers(self.headers.clone())
278 }
279}
280
281impl ProviderClient for Client<reqwest::Client> {
282 fn from_env() -> Self {
285 let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
286 Self::new(&api_key)
287 }
288
289 fn from_val(input: crate::client::ProviderValue) -> Self {
290 let crate::client::ProviderValue::Simple(api_key) = input else {
291 panic!("Incorrect provider value type")
292 };
293 Self::new(&api_key)
294 }
295}
296
297impl CompletionClient for Client<reqwest::Client> {
298 type CompletionModel = CompletionModel<reqwest::Client>;
299 fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
301 CompletionModel::new(self.to_owned(), model)
302 }
303}
304
305impl VerifyClient for Client<reqwest::Client> {
306 #[cfg_attr(feature = "worker", worker::send)]
307 async fn verify(&self) -> Result<(), VerifyError> {
308 let req = self
309 .get("/user-credits")?
310 .body(http_client::NoBody)
311 .map_err(http_client::Error::from)?;
312
313 let response = HttpClientExt::send(&self.http_client, req).await?;
314
315 match response.status() {
316 reqwest::StatusCode::OK => Ok(()),
317 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
318 reqwest::StatusCode::INTERNAL_SERVER_ERROR
319 | reqwest::StatusCode::SERVICE_UNAVAILABLE
320 | reqwest::StatusCode::BAD_GATEWAY => {
321 let text = http_client::text(response).await?;
322 Err(VerifyError::ProviderError(text))
323 }
324 _ => {
325 Ok(())
327 }
328 }
329 }
330}
331
// Generates the `AsEmbeddings`, `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` conversion-trait impls for `Client<T>`. This file implements none
// of those capabilities itself, so these are presumably the macro's "not supported"
// defaults (TODO confirm against the `impl_conversion_traits!` definition).
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
338
339#[derive(Clone)]
340pub struct CompletionModel<T> {
341 client: Client<T>,
342 pub model: String,
344}
345
346impl<T> CompletionModel<T> {
347 pub fn new(client: Client<T>, model: &str) -> Self {
348 Self {
349 client,
350 model: model.to_string(),
351 }
352 }
353
354 fn create_completion_request(
355 &self,
356 completion_request: CompletionRequest,
357 ) -> Result<Value, CompletionError> {
358 if completion_request.tool_choice.is_some() {
359 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
360 }
361
362 let mut messages = Vec::new();
363
364 if let Some(preamble) = &completion_request.preamble {
366 messages.push(serde_json::json!({
367 "role": "user",
368 "content": preamble.to_string()
369 }));
370 }
371
372 if let Some(Message::User { content }) = completion_request.normalized_documents() {
374 let text = content
375 .into_iter()
376 .filter_map(|doc| match doc {
377 UserContent::Document(Document {
378 data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data),
379 ..
380 }) => Some(data),
381 UserContent::Text(text) => Some(text.text),
382
383 _ => None,
385 })
386 .collect::<Vec<_>>()
387 .join("\n");
388
389 messages.push(serde_json::json!({
390 "role": "user",
391 "content": text
392 }));
393 }
394
395 for msg in completion_request.chat_history {
397 let (role, content) = match msg {
398 Message::User { content } => {
399 let text = content
400 .iter()
401 .map(|c| match c {
402 UserContent::Text(text) => &text.text,
403 _ => "",
404 })
405 .collect::<Vec<_>>()
406 .join("\n");
407 ("user", text)
408 }
409 Message::Assistant { content, .. } => {
410 let text = content
411 .iter()
412 .map(|c| match c {
413 AssistantContent::Text(text) => &text.text,
414 _ => "",
415 })
416 .collect::<Vec<_>>()
417 .join("\n");
418 ("assistant", text)
419 }
420 };
421 messages.push(serde_json::json!({
422 "role": role,
423 "content": content
424 }));
425 }
426
427 let request = serde_json::json!({
428 "model": self.model,
429 "messages": messages,
430 "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
431 "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
432 "stream": false
433 });
434
435 Ok(request)
436 }
437}
438
439impl completion::CompletionModel for CompletionModel<reqwest::Client> {
440 type Response = CompletionResponse;
441 type StreamingResponse = openai::StreamingCompletionResponse;
442
443 #[cfg_attr(feature = "worker", worker::send)]
444 async fn completion(
445 &self,
446 completion_request: CompletionRequest,
447 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
448 if !completion_request.tools.is_empty() {
449 tracing::warn!(target: "rig::completions",
450 "Tool calls are not supported by the Mira provider. {len} tools will be ignored.",
451 len = completion_request.tools.len()
452 );
453 }
454
455 let preamble = completion_request.preamble.clone();
456
457 let request = self.create_completion_request(completion_request)?;
458
459 let span = if tracing::Span::current().is_disabled() {
460 info_span!(
461 target: "rig::completions",
462 "chat",
463 gen_ai.operation.name = "chat",
464 gen_ai.provider.name = "mira",
465 gen_ai.request.model = self.model,
466 gen_ai.system_instructions = preamble,
467 gen_ai.response.id = tracing::field::Empty,
468 gen_ai.response.model = tracing::field::Empty,
469 gen_ai.usage.output_tokens = tracing::field::Empty,
470 gen_ai.usage.input_tokens = tracing::field::Empty,
471 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
472 gen_ai.output.messages = tracing::field::Empty,
473 )
474 } else {
475 tracing::Span::current()
476 };
477
478 let async_block = async move {
479 let response = self
480 .client
481 .reqwest_post("/v1/chat/completions")
482 .json(&request)
483 .send()
484 .await
485 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
486
487 if !response.status().is_success() {
488 let status = response.status().as_u16();
489 let error_text = response.text().await.unwrap_or_default();
490 return Err(CompletionError::ProviderError(format!(
491 "API error: {status} - {error_text}"
492 )));
493 }
494
495 let response: CompletionResponse = response
496 .json()
497 .await
498 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
499
500 if let CompletionResponse::Structured {
501 id,
502 model,
503 choices,
504 usage,
505 ..
506 } = &response
507 {
508 let span = tracing::Span::current();
509 span.record("gen_ai.response.model_name", model);
510 span.record("gen_ai.response.id", id);
511 span.record(
512 "gen_ai.output.messages",
513 serde_json::to_string(choices).unwrap(),
514 );
515 if let Some(usage) = usage {
516 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
517 span.record(
518 "gen_ai.usage.output_tokens",
519 usage.total_tokens - usage.prompt_tokens,
520 );
521 }
522 }
523
524 response.try_into()
525 };
526
527 async_block.instrument(span).await
528 }
529
530 #[cfg_attr(feature = "worker", worker::send)]
531 async fn stream(
532 &self,
533 completion_request: CompletionRequest,
534 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
535 let preamble = completion_request.preamble.clone();
536 let mut request = self.create_completion_request(completion_request)?;
537
538 let span = if tracing::Span::current().is_disabled() {
539 info_span!(
540 target: "rig::completions",
541 "chat_streaming",
542 gen_ai.operation.name = "chat_streaming",
543 gen_ai.provider.name = "mira",
544 gen_ai.request.model = self.model,
545 gen_ai.system_instructions = preamble,
546 gen_ai.response.id = tracing::field::Empty,
547 gen_ai.response.model = tracing::field::Empty,
548 gen_ai.usage.output_tokens = tracing::field::Empty,
549 gen_ai.usage.input_tokens = tracing::field::Empty,
550 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
551 gen_ai.output.messages = tracing::field::Empty,
552 )
553 } else {
554 tracing::Span::current()
555 };
556 request = merge(request, json!({"stream": true}));
557
558 let builder = self
559 .client
560 .reqwest_post("/v1/chat/completions")
561 .json(&request);
562
563 send_compatible_streaming_request(builder)
564 .instrument(span)
565 .await
566 }
567}
568
569impl From<ApiErrorResponse> for CompletionError {
570 fn from(err: ApiErrorResponse) -> Self {
571 CompletionError::ProviderError(err.message)
572 }
573}
574
575impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
576 type Error = CompletionError;
577
578 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
579 let (content, usage) = match &response {
580 CompletionResponse::Structured { choices, usage, .. } => {
581 let choice = choices.first().ok_or_else(|| {
582 CompletionError::ResponseError("Response contained no choices".to_owned())
583 })?;
584
585 let usage = usage
586 .as_ref()
587 .map(|usage| completion::Usage {
588 input_tokens: usage.prompt_tokens as u64,
589 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
590 total_tokens: usage.total_tokens as u64,
591 })
592 .unwrap_or_default();
593
594 let message = message::Message::try_from(choice.message.clone())?;
596
597 let content = match message {
598 Message::Assistant { content, .. } => {
599 if content.is_empty() {
600 return Err(CompletionError::ResponseError(
601 "Response contained empty content".to_owned(),
602 ));
603 }
604
605 for c in content.iter() {
607 if !matches!(c, AssistantContent::Text(_)) {
608 tracing::warn!(target: "rig",
609 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
610 );
611 }
612 }
613
614 content.iter().map(|c| {
615 match c {
616 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
617 other => Err(CompletionError::ResponseError(
618 format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
619 ))
620 }
621 }).collect::<Result<Vec<_>, _>>()?
622 }
623 Message::User { .. } => {
624 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
625 return Err(CompletionError::ResponseError(
626 "Received user message in response where assistant message was expected".to_owned()
627 ));
628 }
629 };
630
631 (content, usage)
632 }
633 CompletionResponse::Simple(text) => (
634 vec![completion::AssistantContent::text(text)],
635 completion::Usage::new(),
636 ),
637 };
638
639 let choice = OneOrMany::many(content).map_err(|_| {
640 CompletionError::ResponseError(
641 "Response contained no message or tool call (empty)".to_owned(),
642 )
643 })?;
644
645 Ok(completion::CompletionResponse {
646 choice,
647 usage,
648 raw_response: response,
649 })
650 }
651}
652
653#[derive(Clone, Debug, Deserialize, Serialize)]
654pub struct Usage {
655 pub prompt_tokens: usize,
656 pub total_tokens: usize,
657}
658
659impl std::fmt::Display for Usage {
660 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
661 write!(
662 f,
663 "Prompt tokens: {} Total tokens: {}",
664 self.prompt_tokens, self.total_tokens
665 )
666 }
667}
668
669impl From<Message> for serde_json::Value {
670 fn from(msg: Message) -> Self {
671 match msg {
672 Message::User { content } => {
673 let text = content
674 .iter()
675 .map(|c| match c {
676 UserContent::Text(text) => &text.text,
677 _ => "",
678 })
679 .collect::<Vec<_>>()
680 .join("\n");
681 serde_json::json!({
682 "role": "user",
683 "content": text
684 })
685 }
686 Message::Assistant { content, .. } => {
687 let text = content
688 .iter()
689 .map(|c| match c {
690 AssistantContent::Text(text) => &text.text,
691 _ => "",
692 })
693 .collect::<Vec<_>>()
694 .join("\n");
695 serde_json::json!({
696 "role": "assistant",
697 "content": text
698 })
699 }
700 }
701 }
702}
703
704impl TryFrom<serde_json::Value> for Message {
705 type Error = CompletionError;
706
707 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
708 let role = value["role"].as_str().ok_or_else(|| {
709 CompletionError::ResponseError("Message missing role field".to_owned())
710 })?;
711
712 let content = match value.get("content") {
714 Some(content) => match content {
715 serde_json::Value::String(s) => s.clone(),
716 serde_json::Value::Array(arr) => arr
717 .iter()
718 .filter_map(|c| {
719 c.get("text")
720 .and_then(|t| t.as_str())
721 .map(|text| text.to_string())
722 })
723 .collect::<Vec<_>>()
724 .join("\n"),
725 _ => {
726 return Err(CompletionError::ResponseError(
727 "Message content must be string or array".to_owned(),
728 ));
729 }
730 },
731 None => {
732 return Err(CompletionError::ResponseError(
733 "Message missing content field".to_owned(),
734 ));
735 }
736 };
737
738 match role {
739 "user" => Ok(Message::User {
740 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
741 }),
742 "assistant" => Ok(Message::Assistant {
743 id: None,
744 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
745 }),
746 _ => Err(CompletionError::ResponseError(format!(
747 "Unsupported message role: {role}"
748 ))),
749 }
750 }
751}
752
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// Assistant text shared by the string-content and array-content fixtures.
    const ASSISTANT_GREETING: &str = "Hello there, how may I assist you today?";

    /// Returns the first content item of an assistant message; panics on any other role.
    fn first_assistant_content(message: Message) -> AssistantContent {
        let Message::Assistant { content, .. } = message else {
            panic!("Expected assistant message");
        };
        content.first()
    }

    #[test]
    fn test_deserialize_message() {
        let expected_greeting = AssistantContent::Text(message::Text {
            text: ASSISTANT_GREETING.to_string(),
        });

        // Assistant message with plain string content.
        let assistant = Message::try_from(json!({
            "role": "assistant",
            "content": ASSISTANT_GREETING
        }))
        .unwrap();
        assert_eq!(first_assistant_content(assistant), expected_greeting);

        // User message with plain string content.
        let user = Message::try_from(json!({
            "role": "user",
            "content": "What can you help me with?"
        }))
        .unwrap();
        let Message::User { content } = user else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text(message::Text {
                text: "What can you help me with?".to_string()
            })
        );

        // Assistant message whose content is an array of typed parts.
        let assistant_array = Message::try_from(json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": ASSISTANT_GREETING
            }]
        }))
        .unwrap();
        assert_eq!(first_assistant_content(assistant_array), expected_greeting);
    }

    #[test]
    fn test_message_conversion() {
        let original = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let as_json = serde_json::Value::from(original.clone());
        let round_tripped = Message::try_from(as_json).unwrap();

        assert_eq!(original, round_tripped);
    }

    #[test]
    fn test_completion_response_conversion() {
        let choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };

        let raw = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![choice],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let converted =
            completion::CompletionResponse::<CompletionResponse>::try_from(raw).unwrap();

        assert_eq!(
            converted.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
}