1use crate::json_utils::merge;
11use crate::providers::openai::send_compatible_streaming_request;
12use crate::streaming::{StreamingCompletionModel, StreamingResult};
13use crate::{
14 agent::AgentBuilder,
15 completion::{self, CompletionError, CompletionRequest},
16 extractor::ExtractorBuilder,
17 message::{self, AssistantContent, Message, UserContent},
18 OneOrMany,
19};
20use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use serde_json::{json, Value};
24use std::string::FromUtf8Error;
25use thiserror::Error;
26use tracing;
27
28#[derive(Debug, Error)]
29pub enum MiraError {
30 #[error("Invalid API key")]
31 InvalidApiKey,
32 #[error("API error: {0}")]
33 ApiError(u16),
34 #[error("Request error: {0}")]
35 RequestError(#[from] reqwest::Error),
36 #[error("UTF-8 error: {0}")]
37 Utf8Error(#[from] FromUtf8Error),
38 #[error("JSON error: {0}")]
39 JsonError(#[from] serde_json::Error),
40}
41
42#[derive(Debug, Deserialize)]
43struct ApiErrorResponse {
44 message: String,
45}
46
47#[derive(Debug, Deserialize, Clone)]
48pub struct RawMessage {
49 pub role: String,
50 pub content: String,
51}
52
/// Default root URL of the Mira API, used by `Client::new`.
/// Endpoint paths such as `/v1/models` are appended to it.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
54
55impl TryFrom<RawMessage> for message::Message {
56 type Error = CompletionError;
57
58 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
59 match raw.role.as_str() {
60 "user" => Ok(message::Message::User {
61 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
62 }),
63 "assistant" => Ok(message::Message::Assistant {
64 content: OneOrMany::one(AssistantContent::Text(message::Text {
65 text: raw.content,
66 })),
67 }),
68 _ => Err(CompletionError::ResponseError(format!(
69 "Unsupported message role: {}",
70 raw.role
71 ))),
72 }
73 }
74}
75
76#[derive(Debug, Deserialize)]
77#[serde(untagged)]
78pub enum CompletionResponse {
79 Structured {
80 id: String,
81 object: String,
82 created: u64,
83 model: String,
84 choices: Vec<ChatChoice>,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 usage: Option<Usage>,
87 },
88 Simple(String),
89}
90
91#[derive(Debug, Deserialize)]
92pub struct ChatChoice {
93 pub message: RawMessage,
94 #[serde(default)]
95 pub finish_reason: Option<String>,
96 #[serde(default)]
97 pub index: Option<usize>,
98}
99
100#[derive(Debug, Deserialize)]
101struct ModelsResponse {
102 data: Vec<ModelInfo>,
103}
104
105#[derive(Debug, Deserialize)]
106struct ModelInfo {
107 id: String,
108}
109
110#[derive(Clone)]
111pub struct Client {
113 base_url: String,
114 client: reqwest::Client,
115 headers: HeaderMap,
116}
117
118impl Client {
119 pub fn new(api_key: &str) -> Result<Self, MiraError> {
121 let mut headers = HeaderMap::new();
122 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
123 headers.insert(
124 AUTHORIZATION,
125 HeaderValue::from_str(&format!("Bearer {}", api_key))
126 .map_err(|_| MiraError::InvalidApiKey)?,
127 );
128 headers.insert(
129 reqwest::header::ACCEPT,
130 HeaderValue::from_static("application/json"),
131 );
132 headers.insert(
133 reqwest::header::USER_AGENT,
134 HeaderValue::from_static("rig-client/1.0"),
135 );
136
137 Ok(Self {
138 base_url: MIRA_API_BASE_URL.to_string(),
139 client: reqwest::Client::builder()
140 .build()
141 .expect("Failed to build HTTP client"),
142 headers,
143 })
144 }
145
146 pub fn from_env() -> Result<Self, MiraError> {
149 let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
150 Self::new(&api_key)
151 }
152
153 pub fn new_with_base_url(
155 api_key: &str,
156 base_url: impl Into<String>,
157 ) -> Result<Self, MiraError> {
158 let mut client = Self::new(api_key)?;
159 client.base_url = base_url.into();
160 Ok(client)
161 }
162
163 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
165 let url = format!("{}/v1/models", self.base_url);
166
167 let response = self
168 .client
169 .get(&url)
170 .headers(self.headers.clone())
171 .send()
172 .await?;
173
174 let status = response.status();
175
176 if !status.is_success() {
177 let _error_text = response.text().await.unwrap_or_default();
179 tracing::error!("Error response: {}", _error_text);
180 return Err(MiraError::ApiError(status.as_u16()));
181 }
182
183 let response_text = response.text().await?;
184
185 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
186 tracing::error!("Failed to parse response: {}", e);
187 MiraError::JsonError(e)
188 })?;
189
190 Ok(models.data.into_iter().map(|model| model.id).collect())
191 }
192
193 pub fn completion_model(&self, model: &str) -> CompletionModel {
195 CompletionModel::new(self.to_owned(), model)
196 }
197
198 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
200 AgentBuilder::new(self.completion_model(model))
201 }
202
203 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
205 &self,
206 model: &str,
207 ) -> ExtractorBuilder<T, CompletionModel> {
208 ExtractorBuilder::new(self.completion_model(model))
209 }
210}
211
212#[derive(Clone)]
213pub struct CompletionModel {
214 client: Client,
215 pub model: String,
217}
218
219impl CompletionModel {
220 pub fn new(client: Client, model: &str) -> Self {
221 Self {
222 client,
223 model: model.to_string(),
224 }
225 }
226
227 fn create_completion_request(
228 &self,
229 completion_request: CompletionRequest,
230 ) -> Result<Value, CompletionError> {
231 let mut messages = Vec::new();
232
233 if let Some(preamble) = &completion_request.preamble {
235 messages.push(serde_json::json!({
236 "role": "user",
237 "content": preamble.to_string()
238 }));
239 }
240
241 if let Some(Message::User { content }) = completion_request.normalized_documents() {
243 let text = content
244 .into_iter()
245 .filter_map(|doc| match doc {
246 UserContent::Document(doc) => Some(doc.data),
247 UserContent::Text(text) => Some(text.text),
248
249 _ => None,
251 })
252 .collect::<Vec<_>>()
253 .join("\n");
254
255 messages.push(serde_json::json!({
256 "role": "user",
257 "content": text
258 }));
259 }
260
261 for msg in completion_request.chat_history {
263 let (role, content) = match msg {
264 Message::User { content } => {
265 let text = content
266 .iter()
267 .map(|c| match c {
268 UserContent::Text(text) => &text.text,
269 _ => "",
270 })
271 .collect::<Vec<_>>()
272 .join("\n");
273 ("user", text)
274 }
275 Message::Assistant { content } => {
276 let text = content
277 .iter()
278 .map(|c| match c {
279 AssistantContent::Text(text) => &text.text,
280 _ => "",
281 })
282 .collect::<Vec<_>>()
283 .join("\n");
284 ("assistant", text)
285 }
286 };
287 messages.push(serde_json::json!({
288 "role": role,
289 "content": content
290 }));
291 }
292
293 let request = serde_json::json!({
294 "model": self.model,
295 "messages": messages,
296 "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
297 "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
298 "stream": false
299 });
300
301 Ok(request)
302 }
303}
304
305impl completion::CompletionModel for CompletionModel {
306 type Response = CompletionResponse;
307
308 #[cfg_attr(feature = "worker", worker::send)]
309 async fn completion(
310 &self,
311 completion_request: CompletionRequest,
312 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
313 if !completion_request.tools.is_empty() {
314 tracing::warn!(target: "rig",
315 "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
316 completion_request.tools.len()
317 );
318 }
319
320 let mira_request = self.create_completion_request(completion_request)?;
321
322 let response = self
323 .client
324 .client
325 .post(format!("{}/v1/chat/completions", self.client.base_url))
326 .headers(self.client.headers.clone())
327 .json(&mira_request)
328 .send()
329 .await
330 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
331
332 if !response.status().is_success() {
333 let status = response.status().as_u16();
334 let error_text = response.text().await.unwrap_or_default();
335 return Err(CompletionError::ProviderError(format!(
336 "API error: {} - {}",
337 status, error_text
338 )));
339 }
340
341 let response: CompletionResponse = response
342 .json()
343 .await
344 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
345
346 response.try_into()
347 }
348}
349
350impl StreamingCompletionModel for CompletionModel {
351 async fn stream(
352 &self,
353 completion_request: CompletionRequest,
354 ) -> Result<StreamingResult, CompletionError> {
355 let mut request = self.create_completion_request(completion_request)?;
356
357 request = merge(request, json!({"stream": true}));
358
359 let builder = self
360 .client
361 .client
362 .post(format!("{}/v1/chat/completions", self.client.base_url))
363 .headers(self.client.headers.clone())
364 .json(&request);
365
366 send_compatible_streaming_request(builder).await
367 }
368}
369
370impl From<ApiErrorResponse> for CompletionError {
371 fn from(err: ApiErrorResponse) -> Self {
372 CompletionError::ProviderError(err.message)
373 }
374}
375
376impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
377 type Error = CompletionError;
378
379 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
380 let content = match &response {
381 CompletionResponse::Structured { choices, .. } => {
382 let choice = choices.first().ok_or_else(|| {
383 CompletionError::ResponseError("Response contained no choices".to_owned())
384 })?;
385
386 let message = message::Message::try_from(choice.message.clone())?;
388
389 match message {
390 Message::Assistant { content } => {
391 if content.is_empty() {
392 return Err(CompletionError::ResponseError(
393 "Response contained empty content".to_owned(),
394 ));
395 }
396
397 for c in content.iter() {
399 if !matches!(c, AssistantContent::Text(_)) {
400 tracing::warn!(target: "rig",
401 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
402 );
403 }
404 }
405
406 content.iter().map(|c| {
407 match c {
408 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
409 other => Err(CompletionError::ResponseError(
410 format!("Unsupported content type: {:?}. The Mira provider currently only supports text content", other)
411 ))
412 }
413 }).collect::<Result<Vec<_>, _>>()?
414 }
415 Message::User { .. } => {
416 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
417 return Err(CompletionError::ResponseError(
418 "Received user message in response where assistant message was expected".to_owned()
419 ));
420 }
421 }
422 }
423 CompletionResponse::Simple(text) => {
424 vec![completion::AssistantContent::text(text)]
425 }
426 };
427
428 let choice = OneOrMany::many(content).map_err(|_| {
429 CompletionError::ResponseError(
430 "Response contained no message or tool call (empty)".to_owned(),
431 )
432 })?;
433
434 Ok(completion::CompletionResponse {
435 choice,
436 raw_response: response,
437 })
438 }
439}
440
441#[derive(Clone, Debug, Deserialize)]
442pub struct Usage {
443 pub prompt_tokens: usize,
444 pub total_tokens: usize,
445}
446
447impl std::fmt::Display for Usage {
448 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
449 write!(
450 f,
451 "Prompt tokens: {} Total tokens: {}",
452 self.prompt_tokens, self.total_tokens
453 )
454 }
455}
456
457impl From<Message> for serde_json::Value {
458 fn from(msg: Message) -> Self {
459 match msg {
460 Message::User { content } => {
461 let text = content
462 .iter()
463 .map(|c| match c {
464 UserContent::Text(text) => &text.text,
465 _ => "",
466 })
467 .collect::<Vec<_>>()
468 .join("\n");
469 serde_json::json!({
470 "role": "user",
471 "content": text
472 })
473 }
474 Message::Assistant { content } => {
475 let text = content
476 .iter()
477 .map(|c| match c {
478 AssistantContent::Text(text) => &text.text,
479 _ => "",
480 })
481 .collect::<Vec<_>>()
482 .join("\n");
483 serde_json::json!({
484 "role": "assistant",
485 "content": text
486 })
487 }
488 }
489 }
490}
491
492impl TryFrom<serde_json::Value> for Message {
493 type Error = CompletionError;
494
495 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
496 let role = value["role"].as_str().ok_or_else(|| {
497 CompletionError::ResponseError("Message missing role field".to_owned())
498 })?;
499
500 let content = match value.get("content") {
502 Some(content) => match content {
503 serde_json::Value::String(s) => s.clone(),
504 serde_json::Value::Array(arr) => arr
505 .iter()
506 .filter_map(|c| {
507 c.get("text")
508 .and_then(|t| t.as_str())
509 .map(|text| text.to_string())
510 })
511 .collect::<Vec<_>>()
512 .join("\n"),
513 _ => {
514 return Err(CompletionError::ResponseError(
515 "Message content must be string or array".to_owned(),
516 ))
517 }
518 },
519 None => {
520 return Err(CompletionError::ResponseError(
521 "Message missing content field".to_owned(),
522 ))
523 }
524 };
525
526 match role {
527 "user" => Ok(Message::User {
528 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
529 }),
530 "assistant" => Ok(Message::Assistant {
531 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
532 }),
533 _ => Err(CompletionError::ResponseError(format!(
534 "Unsupported message role: {}",
535 role
536 ))),
537 }
538 }
539}
540
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    const GREETING: &str = "Hello there, how may I assist you today?";
    const QUESTION: &str = "What can you help me with?";

    /// Returns the first content item of an assistant message; panics for any other role.
    fn first_assistant_item(msg: Message) -> AssistantContent {
        if let Message::Assistant { content } = msg {
            content.first()
        } else {
            panic!("Expected assistant message")
        }
    }

    /// Returns the first content item of a user message; panics for any other role.
    fn first_user_item(msg: Message) -> UserContent {
        if let Message::User { content } = msg {
            content.first()
        } else {
            panic!("Expected user message")
        }
    }

    #[test]
    fn test_deserialize_message() {
        // String content for both roles, plus the array-of-parts form.
        let assistant_message =
            Message::try_from(json!({ "role": "assistant", "content": GREETING })).unwrap();
        let user_message =
            Message::try_from(json!({ "role": "user", "content": QUESTION })).unwrap();
        let assistant_message_array = Message::try_from(json!({
            "role": "assistant",
            "content": [{ "type": "text", "text": GREETING }]
        }))
        .unwrap();

        let expected_greeting = AssistantContent::Text(message::Text {
            text: GREETING.to_string(),
        });
        let expected_question = UserContent::Text(message::Text {
            text: QUESTION.to_string(),
        });

        assert_eq!(first_assistant_item(assistant_message), expected_greeting);
        assert_eq!(first_user_item(user_message), expected_question);
        assert_eq!(
            first_assistant_item(assistant_message_array),
            expected_greeting
        );
    }

    #[test]
    fn test_message_conversion() {
        // Round trip: rig Message -> JSON -> Message must be lossless for plain text.
        let original = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let as_json: serde_json::Value = original.clone().try_into().unwrap();
        let parsed: Message = as_json.try_into().unwrap();
        let round_tripped: message::Message = parsed.try_into().unwrap();

        assert_eq!(original, round_tripped);
    }

    #[test]
    fn test_completion_response_conversion() {
        let only_choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![only_choice],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let converted: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            converted.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
}