1use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::{SystemTime, UNIX_EPOCH};
9
10use axum::extract::State;
11use axum::http::StatusCode;
12use axum::response::{IntoResponse, Response};
13use axum::Json;
14use embacle::types::{
15 ChatMessage, ChatRequest, ErrorKind, LlmCapabilities, ResponseFormat, RunnerError,
16};
17use embacle::FunctionDeclaration;
18use tracing::{debug, error, warn};
19
20use crate::openai_types::{
21 ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, Choice, ContentPart,
22 ErrorResponse, MessageContent, ModelField, MultiplexProviderResult, MultiplexResponse,
23 ResponseFormatRequest, ResponseMessage, StopField, ToolCall, ToolCallFunction, ToolChoice,
24 Usage,
25};
26use crate::provider_resolver::resolve_model;
27use crate::runner::multiplex::{MultiplexEngine, MultiplexParams};
28use crate::state::SharedState;
29use crate::streaming;
30
/// Upper bound accepted for the OpenAI-compatible `temperature` parameter.
const MAX_TEMPERATURE: f32 = 2.0;
33
34pub async fn handle(
39 State(state): State<SharedState>,
40 Json(request): Json<ChatCompletionRequest>,
41) -> Response {
42 if let Some(temp) = request.temperature {
43 if !(0.0..=MAX_TEMPERATURE).contains(&temp) {
44 return error_response(
45 StatusCode::BAD_REQUEST,
46 &format!("temperature must be between 0.0 and {MAX_TEMPERATURE}"),
47 );
48 }
49 }
50 if let Some(max) = request.max_tokens {
51 if max == 0 {
52 return error_response(StatusCode::BAD_REQUEST, "max_tokens must be greater than 0");
53 }
54 }
55 if let Some(top_p) = request.top_p {
56 if !(0.0..=1.0).contains(&top_p) {
57 return error_response(StatusCode::BAD_REQUEST, "top_p must be between 0.0 and 1.0");
58 }
59 }
60 if let Some(ref stop) = request.stop {
61 if stop.len() > 4 {
62 return error_response(
63 StatusCode::BAD_REQUEST,
64 "stop must have at most 4 sequences",
65 );
66 }
67 }
68
69 match request.model {
70 ModelField::Multiple(ref models) if models.len() > 1 => {
71 handle_multiplex(&state, &request, models).await
72 }
73 ModelField::Multiple(ref models) if models.len() == 1 => {
74 handle_single(&state, &request, &models[0]).await
75 }
76 ModelField::Multiple(_) => {
77 error_response(StatusCode::BAD_REQUEST, "Model array must not be empty")
78 }
79 ModelField::Single(ref model) => handle_single(&state, &request, model).await,
80 }
81}
82
/// Execute a completion request against a single resolved provider.
///
/// Resolves `model_str` against the active provider, converts the wire-format
/// messages into core `ChatMessage`s, injects a tool catalog when tools are in
/// play, validates provider capabilities, and hands off to
/// `dispatch_completion`.
async fn handle_single(
    state: &SharedState,
    request: &ChatCompletionRequest,
    model_str: &str,
) -> Response {
    // Tools count as active only when present, non-empty, and not explicitly
    // disabled via `tool_choice: "none"`.
    let has_tools = request
        .tools
        .as_ref()
        .is_some_and(|t| !t.is_empty() && !is_tool_choice_none(request.tool_choice.as_ref()));

    let state_guard = state.read().await;
    let resolved = resolve_model(model_str, state_guard.active_provider());
    debug!(
        provider = %resolved.runner_type,
        model = ?resolved.model,
        stream = request.stream,
        has_tools,
        "Dispatching completion"
    );

    let runner = match state_guard.get_runner(resolved.runner_type).await {
        Ok(r) => r,
        Err(e) => return runner_error_to_response(&e),
    };
    // Release the shared-state read lock before any long-running provider work.
    drop(state_guard);

    // Strict capability checking: per-request override wins, then the
    // EMBACLE_STRICT_CAPS env var ("true"/"1"), defaulting to off.
    // NOTE(review): duplicated in handle_multiplex — consider a shared helper.
    let strict = request.strict_capabilities.unwrap_or_else(|| {
        std::env::var("EMBACLE_STRICT_CAPS")
            .map(|v| v == "true" || v == "1")
            .unwrap_or(false)
    });

    let mut messages = convert_messages(&request.messages);

    if has_tools {
        // The tool catalog is injected as prompt text: as a system message
        // when the provider supports them, otherwise merged into the last
        // user message.
        let declarations = tools_to_declarations(request.tools.as_deref().unwrap_or_default());
        let catalog = embacle::generate_tool_catalog(&declarations);

        if runner
            .capabilities()
            .contains(LlmCapabilities::SYSTEM_MESSAGES)
        {
            embacle::inject_tool_catalog(&mut messages, &catalog);
        } else {
            inject_tool_catalog_as_user_message(&mut messages, &catalog);
        }
    }

    // Assemble the core request from the wire-format fields.
    let mut chat_request = ChatRequest::new(messages);
    chat_request.model = resolved.model;
    chat_request.temperature = request.temperature;
    chat_request.max_tokens = request.max_tokens;
    chat_request.top_p = request.top_p;
    chat_request.stop = request.stop.as_ref().map(StopField::to_bounded_vec);
    chat_request.response_format = request.response_format.as_ref().map(server_format_to_core);
    chat_request.tools = request
        .tools
        .as_ref()
        .map(|tools| tools.iter().map(server_tool_to_core).collect());
    chat_request.tool_choice = request.tool_choice.as_ref().map(server_choice_to_core);

    // In strict mode validation errors abort the request; otherwise they come
    // back as warnings that are echoed in the response body.
    let warnings = match embacle::validate_capabilities(
        runner.name(),
        runner.capabilities(),
        &chat_request,
        strict,
    ) {
        Ok(w) => w,
        Err(e) => return runner_error_to_response(&e),
    };
    let warnings_for_response = if warnings.is_empty() {
        None
    } else {
        Some(warnings)
    };

    let supports_streaming = runner.capabilities().contains(LlmCapabilities::STREAMING);

    dispatch_completion(
        runner.as_ref(),
        resolved.runner_type,
        chat_request,
        request.stream,
        has_tools,
        supports_streaming,
        warnings_for_response,
    )
    .await
}
174
175fn wants_json(format: Option<&ResponseFormat>) -> bool {
177 matches!(
178 format,
179 Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. })
180 )
181}
182
183fn strip_json_fences(content: String, json_mode: bool) -> String {
189 if json_mode {
190 embacle::extract_json_from_response(&content)
191 } else {
192 content
193 }
194}
195
/// Run the completion and shape the HTTP response.
///
/// Three paths:
/// 1. `stream` requested but tools are active or the provider cannot stream:
///    run a blocking complete and replay it as a single SSE event.
/// 2. `stream` requested and supported: true streaming SSE.
/// 3. non-streaming: plain JSON `chat.completion` body.
async fn dispatch_completion(
    runner: &dyn embacle::types::LlmProvider,
    runner_type: embacle::config::CliRunnerType,
    mut chat_request: ChatRequest,
    stream: bool,
    has_tools: bool,
    supports_streaming: bool,
    warnings: Option<Vec<String>>,
) -> Response {
    let json_mode = wants_json(chat_request.response_format.as_ref());

    if stream && (has_tools || !supports_streaming) {
        // Downgrade path: complete non-streaming, then emit one SSE chunk so
        // the client still receives an event-stream response.
        if has_tools {
            debug!("Downgrading stream+tools to non-streaming complete");
        } else {
            debug!(
                provider = runner.name(),
                "Provider does not support streaming; downgrading to non-streaming complete"
            );
        }
        match runner.complete(&chat_request).await {
            Ok(response) => {
                let model_name = format!("{runner_type}:{}", response.model);
                let content = strip_json_fences(response.content, json_mode);
                let (message, finish_reason) = build_response_message(
                    has_tools,
                    content,
                    response.finish_reason,
                    response.tool_calls.as_ref(),
                );
                let reason = finish_reason.as_deref().unwrap_or("stop");
                streaming::sse_single_response(message, reason, &model_name)
            }
            Err(e) => runner_error_to_response(&e),
        }
    } else if stream {
        chat_request.stream = true;
        match runner.complete_stream(&chat_request).await {
            Ok(s) => {
                // NOTE(review): the streamed model name uses the runner's
                // default model rather than the resolved request model —
                // confirm this is intended.
                let model_name = format!("{runner_type}:{}", runner.default_model());
                if json_mode {
                    streaming::sse_response_strip_fences(s, &model_name)
                } else {
                    streaming::sse_response(s, &model_name)
                }
            }
            Err(e) => runner_error_to_response(&e),
        }
    } else {
        // Plain non-streaming JSON response.
        match runner.complete(&chat_request).await {
            Ok(response) => {
                let model_name = format!("{runner_type}:{}", response.model);
                let usage = response.usage.map(|u| Usage {
                    prompt: u.prompt_tokens,
                    completion: u.completion_tokens,
                    total: u.total_tokens,
                });

                let content = strip_json_fences(response.content, json_mode);
                let (message, finish_reason) = build_response_message(
                    has_tools,
                    content,
                    response.finish_reason,
                    response.tool_calls.as_ref(),
                );

                let resp = ChatCompletionResponse {
                    id: generate_id(),
                    object: "chat.completion",
                    created: unix_timestamp(),
                    model: model_name,
                    choices: vec![Choice {
                        index: 0,
                        message,
                        finish_reason,
                    }],
                    usage,
                    warnings,
                };

                (StatusCode::OK, Json(resp)).into_response()
            }
            Err(e) => runner_error_to_response(&e),
        }
    }
}
290
/// Fan a single prompt out to several providers and aggregate the results.
///
/// Streaming is rejected up front. Every provider is capability-validated
/// (warnings are logged; hard failures abort the request) before the
/// multiplex engine executes against all providers.
async fn handle_multiplex(
    state: &SharedState,
    request: &ChatCompletionRequest,
    models: &[String],
) -> Response {
    if request.stream {
        return error_response(
            StatusCode::BAD_REQUEST,
            "Streaming is not supported for multiplex requests",
        );
    }

    // Strict capability checking: per-request override wins, then the
    // EMBACLE_STRICT_CAPS env var ("true"/"1"), defaulting to off.
    // NOTE(review): duplicated in handle_single — consider a shared helper.
    let strict = request.strict_capabilities.unwrap_or_else(|| {
        std::env::var("EMBACLE_STRICT_CAPS")
            .map(|v| v == "true" || v == "1")
            .unwrap_or(false)
    });

    let state_guard = state.read().await;
    let default_provider = state_guard.active_provider();
    let resolved: Vec<_> = models
        .iter()
        .map(|m| resolve_model(m, default_provider))
        .collect();

    let providers: Vec<_> = resolved.iter().map(|r| r.runner_type).collect();
    let messages = convert_messages(&request.messages);

    // A throwaway request used only for capability validation; the engine
    // builds its own per-provider requests from `params` below.
    let mut validation_request = ChatRequest::new(messages.clone());
    validation_request.temperature = request.temperature;
    validation_request.max_tokens = request.max_tokens;
    validation_request.top_p = request.top_p;
    validation_request.stop = request.stop.as_ref().map(StopField::to_bounded_vec);
    validation_request.response_format =
        request.response_format.as_ref().map(server_format_to_core);

    // Validate every provider before running any of them, so a request that
    // cannot succeed fails fast without partial execution.
    for &provider_type in &providers {
        let runner = match state_guard.get_runner(provider_type).await {
            Ok(r) => r,
            Err(e) => return runner_error_to_response(&e),
        };
        match embacle::validate_capabilities(
            runner.name(),
            runner.capabilities(),
            &validation_request,
            strict,
        ) {
            Ok(w) => {
                for warning in &w {
                    warn!(provider = runner.name(), warning = %warning, "Capability warning");
                }
            }
            Err(e) => return runner_error_to_response(&e),
        }
    }

    // Release the read lock before the potentially long multiplex execution.
    drop(state_guard);
    let engine = MultiplexEngine::new(state);
    let params = MultiplexParams {
        temperature: request.temperature,
        max_tokens: request.max_tokens,
        top_p: request.top_p,
        stop: request.stop.as_ref().map(StopField::to_bounded_vec),
        response_format: request.response_format.as_ref().map(server_format_to_core),
    };
    match engine.execute(&messages, &providers, &params).await {
        Ok(result) => {
            let results = result
                .responses
                .into_iter()
                .map(|r| MultiplexProviderResult {
                    provider: r.provider,
                    model: r.model,
                    content: r.content,
                    error: r.error,
                    duration_ms: r.duration_ms,
                })
                .collect();

            let resp = MultiplexResponse {
                id: generate_id(),
                object: "chat.completion.multiplex",
                created: unix_timestamp(),
                results,
                summary: result.summary,
            };

            (StatusCode::OK, Json(resp)).into_response()
        }
        Err(e) => runner_error_to_response(&e),
    }
}
385
/// Convert provider output into an OpenAI-style response message plus a
/// finish reason.
///
/// Precedence:
/// 1. Native (structured) tool calls from the provider, when non-empty.
/// 2. `<tool_call>` blocks parsed out of the text, when tools were requested.
/// 3. Plain assistant text (finish reason defaults to "stop").
fn build_response_message(
    has_tools: bool,
    content: String,
    finish_reason: Option<String>,
    native_tool_calls: Option<&Vec<embacle::ToolCallRequest>>,
) -> (ResponseMessage, Option<String>) {
    // Path 1: the provider returned structured tool calls; pass them through.
    if let Some(calls) = native_tool_calls {
        if !calls.is_empty() {
            let tool_calls: Vec<ToolCall> = calls
                .iter()
                .enumerate()
                .map(|(i, tc)| ToolCall {
                    index: i,
                    id: tc.id.clone(),
                    tool_type: "function".to_owned(),
                    function: ToolCallFunction {
                        name: tc.function_name.clone(),
                        // Arguments are re-serialized; fall back to "{}" if
                        // serialization somehow fails.
                        arguments: serde_json::to_string(&tc.arguments)
                            .unwrap_or_else(|_| "{}".to_owned()),
                    },
                })
                .collect();
            let text_content = if content.is_empty() {
                None
            } else {
                Some(content)
            };
            return (
                ResponseMessage {
                    role: "assistant",
                    content: text_content,
                    tool_calls: Some(tool_calls),
                },
                Some("tool_calls".to_owned()),
            );
        }
    }

    // Path 2: tools were requested, so look for textual <tool_call> blocks.
    if has_tools {
        let parsed_calls = embacle::parse_tool_call_blocks(&content);
        if parsed_calls.is_empty() {
            // No tool calls found — treat as a plain assistant reply.
            (
                ResponseMessage {
                    role: "assistant",
                    content: Some(content),
                    tool_calls: None,
                },
                finish_reason.or_else(|| Some("stop".to_owned())),
            )
        } else {
            // Strip the call blocks; any surrounding prose stays as content.
            let remaining_text = embacle::strip_tool_call_blocks(&content);
            let text_content = if remaining_text.is_empty() {
                None
            } else {
                Some(remaining_text)
            };
            let tool_calls: Vec<ToolCall> = parsed_calls
                .iter()
                .enumerate()
                .map(|(i, fc)| ToolCall {
                    index: i,
                    // Parsed calls have no provider-assigned id; synthesize one.
                    id: generate_tool_call_id(&fc.name, i),
                    tool_type: "function".to_owned(),
                    function: ToolCallFunction {
                        name: fc.name.clone(),
                        arguments: serde_json::to_string(&fc.args)
                            .unwrap_or_else(|_| "{}".to_owned()),
                    },
                })
                .collect();
            (
                ResponseMessage {
                    role: "assistant",
                    content: text_content,
                    tool_calls: Some(tool_calls),
                },
                Some("tool_calls".to_owned()),
            )
        }
    } else {
        // Path 3: ordinary text completion.
        (
            ResponseMessage {
                role: "assistant",
                content: Some(content),
                tool_calls: None,
            },
            finish_reason.or_else(|| Some("stop".to_owned())),
        )
    }
}
480
481fn content_as_text(content: Option<&MessageContent>) -> String {
483 content.map(MessageContent::as_text).unwrap_or_default()
484}
485
486fn parse_data_uri(url: &str) -> Option<embacle::ImagePart> {
490 let rest = url.strip_prefix("data:")?;
491 let (mime_type, data) = rest.split_once(";base64,")?;
492 embacle::ImagePart::new(data, mime_type).ok()
493}
494
495fn extract_images(content: Option<&MessageContent>) -> Option<Vec<embacle::ImagePart>> {
497 let Some(MessageContent::Parts(parts)) = content else {
498 return None;
499 };
500
501 let images: Vec<embacle::ImagePart> = parts
502 .iter()
503 .filter_map(|p| match p {
504 ContentPart::ImageUrl { image_url } => parse_data_uri(&image_url.url),
505 ContentPart::Text { .. } => None,
506 })
507 .collect();
508
509 if images.is_empty() {
510 None
511 } else {
512 Some(images)
513 }
514}
515
/// Convert OpenAI wire-format messages into core `ChatMessage`s.
///
/// Mapping rules:
/// - `system` / `user` / `assistant` map to the corresponding core roles;
///   multipart user content keeps its inline base64 images.
/// - assistant tool calls are serialized back into `<tool_call>` text blocks.
/// - consecutive `tool` messages fold into one user message of tool results.
/// - unknown roles are logged and treated as user messages.
fn convert_messages(messages: &[ChatCompletionMessage]) -> Vec<ChatMessage> {
    let mut result = Vec::with_capacity(messages.len());
    // Manual index loop: the `tool` branch consumes a run of consecutive
    // messages in a single step, so a plain `for` won't do.
    let mut i = 0;

    while i < messages.len() {
        let m = &messages[i];
        match m.role.as_str() {
            "system" => {
                result.push(ChatMessage::system(content_as_text(m.content.as_ref())));
                i += 1;
            }
            "user" => {
                let text = content_as_text(m.content.as_ref());
                let images = extract_images(m.content.as_ref());
                if let Some(imgs) = images {
                    result.push(ChatMessage::user_with_images(text, imgs));
                } else {
                    result.push(ChatMessage::user(text));
                }
                i += 1;
            }
            "assistant" => {
                if let Some(ref tool_calls) = m.tool_calls {
                    // Re-encode structured tool calls as <tool_call> text so
                    // text-only providers see the prior calls in the history.
                    let mut text = content_as_text(m.content.as_ref());
                    for tc in tool_calls {
                        text.push_str("\n<tool_call>\n");
                        let payload = serde_json::json!({
                            "name": tc.function.name,
                            "arguments": serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
                                .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()))
                        });
                        text.push_str(
                            &serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_owned()),
                        );
                        text.push_str("\n</tool_call>");
                    }
                    result.push(ChatMessage::assistant(text));
                } else {
                    result.push(ChatMessage::assistant(content_as_text(m.content.as_ref())));
                }
                i += 1;
            }
            "tool" => {
                // Gather the whole run of consecutive tool messages into one
                // batch of function responses.
                let mut tool_responses = Vec::new();
                while i < messages.len() && messages[i].role == "tool" {
                    let tool_msg = &messages[i];
                    let name = tool_msg.name.as_deref().unwrap_or("unknown");
                    let content_text = content_as_text(tool_msg.content.as_ref());
                    // Prefer structured JSON; fall back to the raw string.
                    let response_value: serde_json::Value = if content_text.is_empty() {
                        serde_json::Value::Null
                    } else {
                        serde_json::from_str(&content_text)
                            .unwrap_or(serde_json::Value::String(content_text))
                    };
                    tool_responses.push(embacle::FunctionResponse {
                        name: name.to_owned(),
                        response: response_value,
                    });
                    i += 1;
                }
                let text = embacle::format_tool_results_as_text(&tool_responses);
                result.push(ChatMessage::user(text));
            }
            other => {
                warn!(role = other, "Unknown message role, mapping to user");
                result.push(ChatMessage::user(content_as_text(m.content.as_ref())));
                i += 1;
            }
        }
    }

    result
}
598
599fn server_tool_to_core(tool: &crate::openai_types::ToolDefinition) -> embacle::ToolDefinition {
601 embacle::ToolDefinition {
602 name: tool.function.name.clone(),
603 description: tool.function.description.clone().unwrap_or_default(),
604 parameters: tool.function.parameters.clone(),
605 }
606}
607
608fn server_choice_to_core(choice: &ToolChoice) -> embacle::ToolChoice {
610 match choice {
611 ToolChoice::Mode(m) => match m.as_str() {
612 "none" => embacle::ToolChoice::None,
613 "required" => embacle::ToolChoice::Required,
614 _ => embacle::ToolChoice::Auto,
615 },
616 ToolChoice::Specific(s) => embacle::ToolChoice::Specific {
617 name: s.function.name.clone(),
618 },
619 }
620}
621
622fn server_format_to_core(format: &ResponseFormatRequest) -> embacle::ResponseFormat {
624 match format {
625 ResponseFormatRequest::Text => embacle::ResponseFormat::Text,
626 ResponseFormatRequest::JsonObject => embacle::ResponseFormat::JsonObject,
627 ResponseFormatRequest::JsonSchema { json_schema } => embacle::ResponseFormat::JsonSchema {
628 name: json_schema.name.clone(),
629 schema: json_schema.schema.clone(),
630 },
631 }
632}
633
634fn tools_to_declarations(
636 tools: &[crate::openai_types::ToolDefinition],
637) -> Vec<FunctionDeclaration> {
638 tools
639 .iter()
640 .map(|t| FunctionDeclaration {
641 name: t.function.name.clone(),
642 description: t.function.description.clone().unwrap_or_default(),
643 parameters: t.function.parameters.clone(),
644 })
645 .collect()
646}
647
648fn inject_tool_catalog_as_user_message(messages: &mut [ChatMessage], catalog: &str) {
654 if let Some(last_user) = messages
655 .iter_mut()
656 .rev()
657 .find(|m| m.role == embacle::types::MessageRole::User)
658 {
659 let augmented = format!("{catalog}\n\n{}", last_user.content);
660 *last_user = ChatMessage::user(augmented);
661 } else {
662 warn!("No user message found for tool catalog injection");
665 }
666}
667
668fn is_tool_choice_none(tool_choice: Option<&ToolChoice>) -> bool {
670 matches!(tool_choice, Some(ToolChoice::Mode(ref m)) if m == "none")
671}
672
/// Build a deterministic id (`call_<name>_<index>`) for a tool call parsed
/// from model text, which carries no provider-assigned id.
fn generate_tool_call_id(name: &str, index: usize) -> String {
    let mut id = String::from("call_");
    id.push_str(name);
    id.push('_');
    id.push_str(&index.to_string());
    id
}
677
678fn runner_error_to_response(err: &RunnerError) -> Response {
680 let (status, error_type) = match err.kind {
681 ErrorKind::BinaryNotFound => (StatusCode::SERVICE_UNAVAILABLE, "provider_not_available"),
682 ErrorKind::AuthFailure => (StatusCode::UNAUTHORIZED, "authentication_error"),
683 ErrorKind::Timeout => (StatusCode::GATEWAY_TIMEOUT, "timeout_error"),
684 ErrorKind::ExternalService => (StatusCode::BAD_GATEWAY, "external_service_error"),
685 ErrorKind::Config => (StatusCode::BAD_REQUEST, "invalid_request_error"),
686 ErrorKind::Guardrail => (StatusCode::BAD_REQUEST, "guardrail_error"),
687 ErrorKind::Internal => (StatusCode::INTERNAL_SERVER_ERROR, "server_error"),
688 };
689
690 error!(kind = ?err.kind, message = %err.message, "Runner error");
691 let body = ErrorResponse::new(error_type, &err.message);
692 (status, Json(body)).into_response()
693}
694
695fn error_response(status: StatusCode, message: &str) -> Response {
697 let body = ErrorResponse::new("invalid_request_error", message);
698 (status, Json(body)).into_response()
699}
700
/// Process-wide sequence component for completion ids.
static ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Generate a unique `chatcmpl-…` id from the current time plus a
/// process-wide counter (so ids minted within the same second still differ).
pub fn generate_id() -> String {
    let seq = ID_COUNTER.fetch_add(1, Ordering::Relaxed);
    let ts = unix_timestamp();
    format!("chatcmpl-{ts:x}{seq:08x}")
}

/// Seconds since the Unix epoch, or 0 if the system clock reads earlier
/// than the epoch.
pub fn unix_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
721
#[cfg(test)]
mod tests {
    use super::*;
    use crate::openai_types::{
        ContentPart, FunctionObject, ImageUrlDetail, ToolCall, ToolCallFunction, ToolDefinition,
    };
    use embacle::types::MessageRole;

    // Helper: build a plain-text chat message with the given role.
    fn text_msg(role: &str, content: Option<&str>) -> ChatCompletionMessage {
        ChatCompletionMessage {
            role: role.to_owned(),
            content: content.map(|c| MessageContent::Text(c.to_owned())),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }

    // --- convert_messages -------------------------------------------------

    #[test]
    fn convert_messages_maps_roles() {
        let openai_msgs = vec![
            text_msg("system", Some("You are helpful")),
            text_msg("user", Some("Hello")),
            text_msg("assistant", Some("Hi there")),
        ];

        let messages = convert_messages(&openai_msgs);
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, MessageRole::System);
        assert_eq!(messages[1].role, MessageRole::User);
        assert_eq!(messages[2].role, MessageRole::Assistant);
    }

    #[test]
    fn convert_unknown_role_defaults_to_user() {
        let openai_msgs = vec![text_msg("function", Some("result"))];

        let messages = convert_messages(&openai_msgs);
        assert_eq!(messages[0].role, MessageRole::User);
    }

    #[test]
    fn convert_assistant_with_tool_calls() {
        let openai_msgs = vec![ChatCompletionMessage {
            role: "assistant".to_owned(),
            content: None,
            tool_calls: Some(vec![ToolCall {
                index: 0,
                id: "call_1".to_owned(),
                tool_type: "function".to_owned(),
                function: ToolCallFunction {
                    name: "get_weather".to_owned(),
                    arguments: r#"{"city":"Paris"}"#.to_owned(),
                },
            }]),
            tool_call_id: None,
            name: None,
        }];

        let messages = convert_messages(&openai_msgs);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, MessageRole::Assistant);
        assert!(messages[0].content.contains("<tool_call>"));
        assert!(messages[0].content.contains("get_weather"));
        assert!(messages[0].content.contains("</tool_call>"));
    }

    #[test]
    fn convert_tool_messages_to_user() {
        // Two consecutive tool messages should fold into a single user turn.
        let openai_msgs = vec![
            ChatCompletionMessage {
                role: "tool".to_owned(),
                content: Some(MessageContent::Text(r#"{"temp":72}"#.to_owned())),
                tool_calls: None,
                tool_call_id: Some("call_1".to_owned()),
                name: Some("get_weather".to_owned()),
            },
            ChatCompletionMessage {
                role: "tool".to_owned(),
                content: Some(MessageContent::Text(r#"{"time":"14:30"}"#.to_owned())),
                tool_calls: None,
                tool_call_id: Some("call_2".to_owned()),
                name: Some("get_time".to_owned()),
            },
        ];

        let messages = convert_messages(&openai_msgs);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, MessageRole::User);
        assert!(messages[0].content.contains("tool_result"));
        assert!(messages[0].content.contains("get_weather"));
        assert!(messages[0].content.contains("get_time"));
    }

    #[test]
    fn convert_messages_none_content() {
        let openai_msgs = vec![text_msg("user", None)];

        let messages = convert_messages(&openai_msgs);
        assert_eq!(messages[0].content, "");
    }

    #[test]
    fn convert_multipart_user_message_extracts_images() {
        let openai_msgs = vec![ChatCompletionMessage {
            role: "user".to_owned(),
            content: Some(MessageContent::Parts(vec![
                ContentPart::Text {
                    text: "What is this?".to_owned(),
                },
                ContentPart::ImageUrl {
                    image_url: ImageUrlDetail {
                        url: "data:image/png;base64,aGVsbG8=".to_owned(),
                    },
                },
            ])),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }];

        let messages = convert_messages(&openai_msgs);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content, "What is this?");
        let images = messages[0].images.as_ref().expect("images present");
        assert_eq!(images.len(), 1);
        assert_eq!(images[0].mime_type, "image/png");
        assert_eq!(images[0].data, "aGVsbG8=");
    }

    // --- parse_data_uri ----------------------------------------------------

    #[test]
    fn parse_data_uri_valid() {
        let img = parse_data_uri("data:image/jpeg;base64,AAAA").expect("should parse");
        assert_eq!(img.mime_type, "image/jpeg");
        assert_eq!(img.data, "AAAA");
    }

    #[test]
    fn parse_data_uri_invalid_format() {
        assert!(parse_data_uri("https://example.com/image.png").is_none());
        assert!(parse_data_uri("data:text/plain;base64,abc").is_none());
        assert!(parse_data_uri("data:image/png;abc").is_none());
    }

    #[test]
    fn convert_plain_string_content_backward_compat() {
        let openai_msgs = vec![text_msg("user", Some("hello"))];
        let messages = convert_messages(&openai_msgs);
        assert_eq!(messages[0].content, "hello");
        assert!(messages[0].images.is_none());
    }

    // --- tool conversion helpers -------------------------------------------

    #[test]
    fn tools_to_declarations_converts() {
        let tools = vec![ToolDefinition {
            tool_type: "function".to_owned(),
            function: FunctionObject {
                name: "search".to_owned(),
                description: Some("Search the web".to_owned()),
                parameters: Some(serde_json::json!({
                    "type": "object",
                    "properties": {"q": {"type": "string"}},
                    "required": ["q"]
                })),
            },
        }];

        let decls = tools_to_declarations(&tools);
        assert_eq!(decls.len(), 1);
        assert_eq!(decls[0].name, "search");
        assert_eq!(decls[0].description, "Search the web");
        assert!(decls[0].parameters.is_some());
    }

    #[test]
    fn tool_choice_none_detection() {
        let none_choice = ToolChoice::Mode("none".to_owned());
        assert!(is_tool_choice_none(Some(&none_choice)));
        let auto_choice = ToolChoice::Mode("auto".to_owned());
        assert!(!is_tool_choice_none(Some(&auto_choice)));
        assert!(!is_tool_choice_none(None));
    }

    #[test]
    fn content_as_text_none() {
        assert_eq!(content_as_text(None), "");
    }

    #[test]
    fn content_as_text_plain() {
        let content = MessageContent::Text("hello".to_owned());
        assert_eq!(content_as_text(Some(&content)), "hello");
    }

    // --- id generation -----------------------------------------------------

    #[test]
    fn generate_tool_call_id_format() {
        let id = generate_tool_call_id("get_weather", 0);
        assert_eq!(id, "call_get_weather_0");
    }

    #[test]
    fn generate_id_has_prefix() {
        let id = generate_id();
        assert!(id.starts_with("chatcmpl-"));
    }

    // --- error mapping -----------------------------------------------------

    #[test]
    fn error_maps_binary_not_found_to_503() {
        let err = RunnerError::binary_not_found("claude");
        let (status, _) = match err.kind {
            ErrorKind::BinaryNotFound => {
                (StatusCode::SERVICE_UNAVAILABLE, "provider_not_available")
            }
            _ => (StatusCode::INTERNAL_SERVER_ERROR, "server_error"),
        };
        assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
    }

    #[test]
    fn error_maps_auth_to_401() {
        let err = RunnerError::auth_failure("bad token");
        let (status, _) = match err.kind {
            ErrorKind::AuthFailure => (StatusCode::UNAUTHORIZED, "authentication_error"),
            _ => (StatusCode::INTERNAL_SERVER_ERROR, "server_error"),
        };
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn error_maps_timeout_to_504() {
        let err = RunnerError::timeout("too slow");
        let (status, _) = match err.kind {
            ErrorKind::Timeout => (StatusCode::GATEWAY_TIMEOUT, "timeout_error"),
            _ => (StatusCode::INTERNAL_SERVER_ERROR, "server_error"),
        };
        assert_eq!(status, StatusCode::GATEWAY_TIMEOUT);
    }

    // --- tool catalog injection --------------------------------------------

    #[test]
    fn inject_tool_catalog_as_user_message_prepends_to_last_user() {
        let mut messages = vec![
            ChatMessage::user("First question"),
            ChatMessage::assistant("Some answer"),
            ChatMessage::user("What is the weather?"),
        ];
        let catalog = "## Available Tools\n- get_weather: Get the weather";

        inject_tool_catalog_as_user_message(&mut messages, catalog);

        assert_eq!(messages.len(), 3);
        assert!(messages[2].content.starts_with("## Available Tools"));
        assert!(messages[2].content.contains("What is the weather?"));
        assert_eq!(messages[0].content, "First question");
    }

    #[test]
    fn inject_tool_catalog_as_user_message_single_user() {
        let mut messages = vec![
            ChatMessage::system("You are helpful"),
            ChatMessage::user("Hello"),
        ];
        let catalog = "## Tools\nsome tools";

        inject_tool_catalog_as_user_message(&mut messages, catalog);

        assert!(messages[1].content.starts_with("## Tools"));
        assert!(messages[1].content.contains("Hello"));
    }

    // --- JSON-mode helpers --------------------------------------------------

    #[test]
    fn wants_json_matches_json_formats() {
        use embacle::types::ResponseFormat;

        assert!(!wants_json(None));
        assert!(!wants_json(Some(&ResponseFormat::Text)));
        assert!(wants_json(Some(&ResponseFormat::JsonObject)));
        assert!(wants_json(Some(&ResponseFormat::JsonSchema {
            name: "test".to_owned(),
            schema: serde_json::json!({}),
        })));
    }

    #[test]
    fn strip_json_fences_removes_markdown_wrapper() {
        let fenced = "```json\n{\"key\":\"value\"}\n```".to_owned();
        assert_eq!(strip_json_fences(fenced, true), "{\"key\":\"value\"}");
    }

    #[test]
    fn strip_json_fences_passes_through_in_text_mode() {
        let fenced = "```json\n{\"key\":\"value\"}\n```".to_owned();
        assert_eq!(strip_json_fences(fenced.clone(), false), fenced);
    }

    #[test]
    fn strip_json_fences_leaves_clean_json_unchanged() {
        let clean = "{\"key\":\"value\"}".to_owned();
        assert_eq!(strip_json_fences(clean.clone(), true), clean);
    }
}