embacle_server/completions.rs

// ABOUTME: POST /v1/chat/completions handler for OpenAI-compatible chat completion
// ABOUTME: Routes to single provider or multiplex, supports both streaming and non-streaming
//
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2026 dravr.ai

use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};

use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use embacle::types::{
    ChatMessage, ChatRequest, ErrorKind, LlmCapabilities, ResponseFormat, RunnerError,
};
use embacle::FunctionDeclaration;
use tracing::{debug, error, warn};

use crate::openai_types::{
    ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, Choice, ContentPart,
    ErrorResponse, MessageContent, ModelField, MultiplexProviderResult, MultiplexResponse,
    ResponseFormatRequest, ResponseMessage, StopField, ToolCall, ToolCallFunction, ToolChoice,
    Usage,
};
use crate::provider_resolver::resolve_model;
use crate::runner::multiplex::{MultiplexEngine, MultiplexParams};
use crate::state::SharedState;
use crate::streaming;

/// OpenAI-specified upper bound for temperature
const MAX_TEMPERATURE: f32 = 2.0;

/// Handle POST /v1/chat/completions
///
/// Dispatches to single-provider or multiplex mode based on the model field.
/// Supports both streaming (SSE) and non-streaming (JSON) responses.
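///
/// A rough sketch of how the `model` field drives dispatch (model names here
/// are placeholders, not real providers):
///
/// ```text
/// {"model": "some-model", ...}             → handle_single
/// {"model": ["model-a", "model-b"], ...}   → handle_multiplex
/// {"model": ["only-model"], ...}           → handle_single
/// {"model": [], ...}                       → 400 Bad Request
/// ```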
pub async fn handle(
    State(state): State<SharedState>,
    Json(request): Json<ChatCompletionRequest>,
) -> Response {
    if let Some(temp) = request.temperature {
        if !(0.0..=MAX_TEMPERATURE).contains(&temp) {
            return error_response(
                StatusCode::BAD_REQUEST,
                &format!("temperature must be between 0.0 and {MAX_TEMPERATURE}"),
            );
        }
    }
    if let Some(max) = request.max_tokens {
        if max == 0 {
            return error_response(StatusCode::BAD_REQUEST, "max_tokens must be greater than 0");
        }
    }
    if let Some(top_p) = request.top_p {
        if !(0.0..=1.0).contains(&top_p) {
            return error_response(StatusCode::BAD_REQUEST, "top_p must be between 0.0 and 1.0");
        }
    }
    if let Some(ref stop) = request.stop {
        if stop.len() > 4 {
            return error_response(
                StatusCode::BAD_REQUEST,
                "stop must have at most 4 sequences",
            );
        }
    }

    match request.model {
        ModelField::Multiple(ref models) if models.len() > 1 => {
            handle_multiplex(&state, &request, models).await
        }
        ModelField::Multiple(ref models) if models.len() == 1 => {
            handle_single(&state, &request, &models[0]).await
        }
        ModelField::Multiple(_) => {
            error_response(StatusCode::BAD_REQUEST, "Model array must not be empty")
        }
        ModelField::Single(ref model) => handle_single(&state, &request, model).await,
    }
}

/// Handle a single-provider request (standard case)
async fn handle_single(
    state: &SharedState,
    request: &ChatCompletionRequest,
    model_str: &str,
) -> Response {
    let has_tools = request
        .tools
        .as_ref()
        .is_some_and(|t| !t.is_empty() && !is_tool_choice_none(request.tool_choice.as_ref()));

    let state_guard = state.read().await;
    let resolved = resolve_model(model_str, state_guard.active_provider());
    debug!(
        provider = %resolved.runner_type,
        model = ?resolved.model,
        stream = request.stream,
        has_tools,
        "Dispatching completion"
    );

    let runner = match state_guard.get_runner(resolved.runner_type).await {
        Ok(r) => r,
        Err(e) => return runner_error_to_response(&e),
    };
    drop(state_guard);
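
    // Request-level `strict_capabilities` takes precedence; otherwise fall
    // back to the EMBACLE_STRICT_CAPS environment variable ("true" or "1"
    // enables strict capability validation).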

    let strict = request.strict_capabilities.unwrap_or_else(|| {
        std::env::var("EMBACLE_STRICT_CAPS")
            .map(|v| v == "true" || v == "1")
            .unwrap_or(false)
    });

    let mut messages = convert_messages(&request.messages);

    // Inject tool catalog using the most effective strategy for this provider
    if has_tools {
        let declarations = tools_to_declarations(request.tools.as_deref().unwrap_or_default());
        let catalog = embacle::generate_tool_catalog(&declarations);

        if runner
            .capabilities()
            .contains(LlmCapabilities::SYSTEM_MESSAGES)
        {
            embacle::inject_tool_catalog(&mut messages, &catalog);
        } else {
            inject_tool_catalog_as_user_message(&mut messages, &catalog);
        }
    }

    let mut chat_request = ChatRequest::new(messages);
    chat_request.model = resolved.model;
    chat_request.temperature = request.temperature;
    chat_request.max_tokens = request.max_tokens;
    chat_request.top_p = request.top_p;
    chat_request.stop = request.stop.as_ref().map(StopField::to_bounded_vec);
    chat_request.response_format = request.response_format.as_ref().map(server_format_to_core);
    chat_request.tools = request
        .tools
        .as_ref()
        .map(|tools| tools.iter().map(server_tool_to_core).collect());
    chat_request.tool_choice = request.tool_choice.as_ref().map(server_choice_to_core);

    let warnings = match embacle::validate_capabilities(
        runner.name(),
        runner.capabilities(),
        &chat_request,
        strict,
    ) {
        Ok(w) => w,
        Err(e) => return runner_error_to_response(&e),
    };
    let warnings_for_response = if warnings.is_empty() {
        None
    } else {
        Some(warnings)
    };

    let supports_streaming = runner.capabilities().contains(LlmCapabilities::STREAMING);

    dispatch_completion(
        runner.as_ref(),
        resolved.runner_type,
        chat_request,
        request.stream,
        has_tools,
        supports_streaming,
        warnings_for_response,
    )
    .await
}

/// Check if the response format requests JSON output
fn wants_json(format: Option<&ResponseFormat>) -> bool {
    matches!(
        format,
        Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. })
    )
}

/// Strip markdown code fences from content when JSON output is requested.
///
/// CLI runners often wrap JSON in `` ```json ... ``` `` fences. When the client
/// has requested `response_format: json_object` or `json_schema`, extract the
/// raw JSON so the response is directly parseable.
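///
/// A sketch of the transformation when `json_mode` is true (fence style as
/// exercised by the tests below):
///
/// ```text
/// "```json\n{\"key\":\"value\"}\n```"  →  "{\"key\":\"value\"}"
/// ```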
fn strip_json_fences(content: String, json_mode: bool) -> String {
    if json_mode {
        embacle::extract_json_from_response(&content)
    } else {
        content
    }
}

/// Dispatch the completion request to the appropriate execution path
///
/// Routes between four modes:
/// 1. Streaming with tools: downgrade to `complete()`, emit as SSE
/// 2. Streaming without provider support: downgrade to `complete()`, emit as SSE
/// 3. Pure streaming: use `complete_stream()`
/// 4. Non-streaming: use `complete()`, return JSON
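///
/// Equivalently, as a decision table over the three flags:
///
/// ```text
/// stream | has_tools | supports_streaming | path
/// -------+-----------+--------------------+--------------------------------
/// true   | true      | (any)              | complete() → single SSE event
/// true   | false     | false              | complete() → single SSE event
/// true   | false     | true               | complete_stream() → SSE stream
/// false  | (any)     | (any)              | complete() → JSON body
/// ```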
async fn dispatch_completion(
    runner: &dyn embacle::types::LlmProvider,
    runner_type: embacle::config::CliRunnerType,
    mut chat_request: ChatRequest,
    stream: bool,
    has_tools: bool,
    supports_streaming: bool,
    warnings: Option<Vec<String>>,
) -> Response {
    let json_mode = wants_json(chat_request.response_format.as_ref());

    if stream && (has_tools || !supports_streaming) {
        // Downgrade to non-streaming complete(), emit result as SSE
        if has_tools {
            debug!("Downgrading stream+tools to non-streaming complete");
        } else {
            debug!(
                provider = runner.name(),
                "Provider does not support streaming; downgrading to non-streaming complete"
            );
        }
        match runner.complete(&chat_request).await {
            Ok(response) => {
                let model_name = format!("{runner_type}:{}", response.model);
                let content = strip_json_fences(response.content, json_mode);
                let (message, finish_reason) = build_response_message(
                    has_tools,
                    content,
                    response.finish_reason,
                    response.tool_calls.as_ref(),
                );
                let reason = finish_reason.as_deref().unwrap_or("stop");
                streaming::sse_single_response(message, reason, &model_name)
            }
            Err(e) => runner_error_to_response(&e),
        }
    } else if stream {
        chat_request.stream = true;
        match runner.complete_stream(&chat_request).await {
            Ok(s) => {
                let model_name = format!("{runner_type}:{}", runner.default_model());
                if json_mode {
                    streaming::sse_response_strip_fences(s, &model_name)
                } else {
                    streaming::sse_response(s, &model_name)
                }
            }
            Err(e) => runner_error_to_response(&e),
        }
    } else {
        match runner.complete(&chat_request).await {
            Ok(response) => {
                let model_name = format!("{runner_type}:{}", response.model);
                let usage = response.usage.map(|u| Usage {
                    prompt: u.prompt_tokens,
                    completion: u.completion_tokens,
                    total: u.total_tokens,
                });

                let content = strip_json_fences(response.content, json_mode);
                let (message, finish_reason) = build_response_message(
                    has_tools,
                    content,
                    response.finish_reason,
                    response.tool_calls.as_ref(),
                );

                let resp = ChatCompletionResponse {
                    id: generate_id(),
                    object: "chat.completion",
                    created: unix_timestamp(),
                    model: model_name,
                    choices: vec![Choice {
                        index: 0,
                        message,
                        finish_reason,
                    }],
                    usage,
                    warnings,
                };

                (StatusCode::OK, Json(resp)).into_response()
            }
            Err(e) => runner_error_to_response(&e),
        }
    }
}

/// Handle a multiplex request (multiple providers)
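///
/// The multiplex response is not the standard OpenAI shape; a sketch of the
/// JSON body (field names taken from `MultiplexResponse` and
/// `MultiplexProviderResult`; values are placeholders):
///
/// ```text
/// {
///   "id": "chatcmpl-…",
///   "object": "chat.completion.multiplex",
///   "created": 1700000000,
///   "results": [
///     { "provider": "…", "model": "…", "content": "…",
///       "error": null, "duration_ms": 1234 }
///   ],
///   "summary": …
/// }
/// ```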
async fn handle_multiplex(
    state: &SharedState,
    request: &ChatCompletionRequest,
    models: &[String],
) -> Response {
    if request.stream {
        return error_response(
            StatusCode::BAD_REQUEST,
            "Streaming is not supported for multiplex requests",
        );
    }

    let strict = request.strict_capabilities.unwrap_or_else(|| {
        std::env::var("EMBACLE_STRICT_CAPS")
            .map(|v| v == "true" || v == "1")
            .unwrap_or(false)
    });

    let state_guard = state.read().await;
    let default_provider = state_guard.active_provider();
    let resolved: Vec<_> = models
        .iter()
        .map(|m| resolve_model(m, default_provider))
        .collect();

    let providers: Vec<_> = resolved.iter().map(|r| r.runner_type).collect();
    let messages = convert_messages(&request.messages);

    // Build a temporary ChatRequest for capability validation
    let mut validation_request = ChatRequest::new(messages.clone());
    validation_request.temperature = request.temperature;
    validation_request.max_tokens = request.max_tokens;
    validation_request.top_p = request.top_p;
    validation_request.stop = request.stop.as_ref().map(StopField::to_bounded_vec);
    validation_request.response_format =
        request.response_format.as_ref().map(server_format_to_core);

    for &provider_type in &providers {
        let runner = match state_guard.get_runner(provider_type).await {
            Ok(r) => r,
            Err(e) => return runner_error_to_response(&e),
        };
        match embacle::validate_capabilities(
            runner.name(),
            runner.capabilities(),
            &validation_request,
            strict,
        ) {
            Ok(w) => {
                for warning in &w {
                    warn!(provider = runner.name(), warning = %warning, "Capability warning");
                }
            }
            Err(e) => return runner_error_to_response(&e),
        }
    }

    drop(state_guard);
    let engine = MultiplexEngine::new(state);
    let params = MultiplexParams {
        temperature: request.temperature,
        max_tokens: request.max_tokens,
        top_p: request.top_p,
        stop: request.stop.as_ref().map(StopField::to_bounded_vec),
        response_format: request.response_format.as_ref().map(server_format_to_core),
    };
    match engine.execute(&messages, &providers, &params).await {
        Ok(result) => {
            let results = result
                .responses
                .into_iter()
                .map(|r| MultiplexProviderResult {
                    provider: r.provider,
                    model: r.model,
                    content: r.content,
                    error: r.error,
                    duration_ms: r.duration_ms,
                })
                .collect();

            let resp = MultiplexResponse {
                id: generate_id(),
                object: "chat.completion.multiplex",
                created: unix_timestamp(),
                results,
                summary: result.summary,
            };

            (StatusCode::OK, Json(resp)).into_response()
        }
        Err(e) => runner_error_to_response(&e),
    }
}

/// Build a `ResponseMessage` from LLM output, using native tool calls if available
/// or falling back to XML parsing if tools were requested
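///
/// Sketch of the text-simulation block the XML fallback looks for (mirroring
/// the shape `convert_messages` reconstructs; exact parsing rules live in
/// `embacle::parse_tool_call_blocks`):
///
/// ```text
/// <tool_call>
/// {"name": "get_weather", "arguments": {"city": "Paris"}}
/// </tool_call>
/// ```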
fn build_response_message(
    has_tools: bool,
    content: String,
    finish_reason: Option<String>,
    native_tool_calls: Option<&Vec<embacle::ToolCallRequest>>,
) -> (ResponseMessage, Option<String>) {
    // If the provider returned native tool calls, use them directly
    if let Some(calls) = native_tool_calls {
        if !calls.is_empty() {
            let tool_calls: Vec<ToolCall> = calls
                .iter()
                .enumerate()
                .map(|(i, tc)| ToolCall {
                    index: i,
                    id: tc.id.clone(),
                    tool_type: "function".to_owned(),
                    function: ToolCallFunction {
                        name: tc.function_name.clone(),
                        arguments: serde_json::to_string(&tc.arguments)
                            .unwrap_or_else(|_| "{}".to_owned()),
                    },
                })
                .collect();
            let text_content = if content.is_empty() {
                None
            } else {
                Some(content)
            };
            return (
                ResponseMessage {
                    role: "assistant",
                    content: text_content,
                    tool_calls: Some(tool_calls),
                },
                Some("tool_calls".to_owned()),
            );
        }
    }

    // Fall back to XML parsing for text-based tool simulation
    if has_tools {
        let parsed_calls = embacle::parse_tool_call_blocks(&content);
        if parsed_calls.is_empty() {
            (
                ResponseMessage {
                    role: "assistant",
                    content: Some(content),
                    tool_calls: None,
                },
                finish_reason.or_else(|| Some("stop".to_owned())),
            )
        } else {
            let remaining_text = embacle::strip_tool_call_blocks(&content);
            let text_content = if remaining_text.is_empty() {
                None
            } else {
                Some(remaining_text)
            };
            let tool_calls: Vec<ToolCall> = parsed_calls
                .iter()
                .enumerate()
                .map(|(i, fc)| ToolCall {
                    index: i,
                    id: generate_tool_call_id(&fc.name, i),
                    tool_type: "function".to_owned(),
                    function: ToolCallFunction {
                        name: fc.name.clone(),
                        arguments: serde_json::to_string(&fc.args)
                            .unwrap_or_else(|_| "{}".to_owned()),
                    },
                })
                .collect();
            (
                ResponseMessage {
                    role: "assistant",
                    content: text_content,
                    tool_calls: Some(tool_calls),
                },
                Some("tool_calls".to_owned()),
            )
        }
    } else {
        (
            ResponseMessage {
                role: "assistant",
                content: Some(content),
                tool_calls: None,
            },
            finish_reason.or_else(|| Some("stop".to_owned())),
        )
    }
}

/// Extract text content from a `MessageContent`, returning an empty string for None
fn content_as_text(content: Option<&MessageContent>) -> String {
    content.map(MessageContent::as_text).unwrap_or_default()
}

/// Parse a `data:` URI into an `ImagePart`
///
/// Expected format: `data:<mime_type>;base64,<data>`
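///
/// For example:
///
/// ```text
/// "data:image/png;base64,aGVsbG8="
///     → ImagePart { mime_type: "image/png", data: "aGVsbG8=" }
/// ```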
fn parse_data_uri(url: &str) -> Option<embacle::ImagePart> {
    let rest = url.strip_prefix("data:")?;
    let (mime_type, data) = rest.split_once(";base64,")?;
    embacle::ImagePart::new(data, mime_type).ok()
}

/// Extract images from a `MessageContent::Parts` variant
fn extract_images(content: Option<&MessageContent>) -> Option<Vec<embacle::ImagePart>> {
    let Some(MessageContent::Parts(parts)) = content else {
        return None;
    };

    let images: Vec<embacle::ImagePart> = parts
        .iter()
        .filter_map(|p| match p {
            ContentPart::ImageUrl { image_url } => parse_data_uri(&image_url.url),
            ContentPart::Text { .. } => None,
        })
        .collect();

    if images.is_empty() {
        None
    } else {
        Some(images)
    }
}

/// Convert `OpenAI` message format to embacle `ChatMessage`
///
/// Handles all `OpenAI` roles including "tool" messages and assistant messages
/// with `tool_calls`. Tool messages are collected and formatted as `<tool_result>`
/// blocks. Assistant messages with `tool_calls` are reconstructed as `<tool_call>` blocks.
/// User messages with multipart content (text + images) are converted to `ChatMessage`
/// with attached `ImagePart` entries.
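///
/// A rough sketch of the tool-message folding (the exact `<tool_result>` text
/// is produced by `embacle::format_tool_results_as_text`):
///
/// ```text
/// tool("get_weather") + tool("get_time")   [consecutive]
///     → one user message carrying both results as <tool_result> blocks
/// ```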
fn convert_messages(messages: &[ChatCompletionMessage]) -> Vec<ChatMessage> {
    let mut result = Vec::with_capacity(messages.len());
    let mut i = 0;

    while i < messages.len() {
        let m = &messages[i];
        match m.role.as_str() {
            "system" => {
                result.push(ChatMessage::system(content_as_text(m.content.as_ref())));
                i += 1;
            }
            "user" => {
                let text = content_as_text(m.content.as_ref());
                let images = extract_images(m.content.as_ref());
                if let Some(imgs) = images {
                    result.push(ChatMessage::user_with_images(text, imgs));
                } else {
                    result.push(ChatMessage::user(text));
                }
                i += 1;
            }
            "assistant" => {
                if let Some(ref tool_calls) = m.tool_calls {
                    // Reconstruct <tool_call> XML blocks from stored tool calls
                    let mut text = content_as_text(m.content.as_ref());
                    for tc in tool_calls {
                        text.push_str("\n<tool_call>\n");
                        let payload = serde_json::json!({
                            "name": tc.function.name,
                            "arguments": serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
                                .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()))
                        });
                        text.push_str(
                            &serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_owned()),
                        );
                        text.push_str("\n</tool_call>");
                    }
                    result.push(ChatMessage::assistant(text));
                } else {
                    result.push(ChatMessage::assistant(content_as_text(m.content.as_ref())));
                }
                i += 1;
            }
            "tool" => {
                // Collect consecutive tool messages into a single user message
                let mut tool_responses = Vec::new();
                while i < messages.len() && messages[i].role == "tool" {
                    let tool_msg = &messages[i];
                    let name = tool_msg.name.as_deref().unwrap_or("unknown");
                    let content_text = content_as_text(tool_msg.content.as_ref());
                    let response_value: serde_json::Value = if content_text.is_empty() {
                        serde_json::Value::Null
                    } else {
                        serde_json::from_str(&content_text)
                            .unwrap_or(serde_json::Value::String(content_text))
                    };
                    tool_responses.push(embacle::FunctionResponse {
                        name: name.to_owned(),
                        response: response_value,
                    });
                    i += 1;
                }
                let text = embacle::format_tool_results_as_text(&tool_responses);
                result.push(ChatMessage::user(text));
            }
            other => {
                warn!(role = other, "Unknown message role, mapping to user");
                result.push(ChatMessage::user(content_as_text(m.content.as_ref())));
                i += 1;
            }
        }
    }

    result
}

/// Convert a server `ToolDefinition` to core `ToolDefinition`
fn server_tool_to_core(tool: &crate::openai_types::ToolDefinition) -> embacle::ToolDefinition {
    embacle::ToolDefinition {
        name: tool.function.name.clone(),
        description: tool.function.description.clone().unwrap_or_default(),
        parameters: tool.function.parameters.clone(),
    }
}

/// Convert a server `ToolChoice` to core `ToolChoice`
fn server_choice_to_core(choice: &ToolChoice) -> embacle::ToolChoice {
    match choice {
        ToolChoice::Mode(m) => match m.as_str() {
            "none" => embacle::ToolChoice::None,
            "required" => embacle::ToolChoice::Required,
            _ => embacle::ToolChoice::Auto,
        },
        ToolChoice::Specific(s) => embacle::ToolChoice::Specific {
            name: s.function.name.clone(),
        },
    }
}

/// Convert a server `ResponseFormatRequest` to core `ResponseFormat`
fn server_format_to_core(format: &ResponseFormatRequest) -> embacle::ResponseFormat {
    match format {
        ResponseFormatRequest::Text => embacle::ResponseFormat::Text,
        ResponseFormatRequest::JsonObject => embacle::ResponseFormat::JsonObject,
        ResponseFormatRequest::JsonSchema { json_schema } => embacle::ResponseFormat::JsonSchema {
            name: json_schema.name.clone(),
            schema: json_schema.schema.clone(),
        },
    }
}

/// Convert `OpenAI` tool definitions to embacle `FunctionDeclaration` format
fn tools_to_declarations(
    tools: &[crate::openai_types::ToolDefinition],
) -> Vec<FunctionDeclaration> {
    tools
        .iter()
        .map(|t| FunctionDeclaration {
            name: t.function.name.clone(),
            description: t.function.description.clone().unwrap_or_default(),
            parameters: t.function.parameters.clone(),
        })
        .collect()
}

/// Inject tool catalog into the last user message content
///
/// Used for providers that do not support system messages (e.g. Copilot CLI).
/// The catalog is prepended to the last user message so the LLM sees it in
/// the conversational flow rather than in a system prompt it cannot parse.
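///
/// For example (catalog text abbreviated):
///
/// ```text
/// user: "What is the weather?"
///     → user: "## Available Tools…\n\nWhat is the weather?"
/// ```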
fn inject_tool_catalog_as_user_message(messages: &mut [ChatMessage], catalog: &str) {
    if let Some(last_user) = messages
        .iter_mut()
        .rev()
        .find(|m| m.role == embacle::types::MessageRole::User)
    {
        let augmented = format!("{catalog}\n\n{}", last_user.content);
        // Note: ChatMessage::user rebuilds the message as text-only, so any
        // images attached to the original last user message are dropped here.
        *last_user = ChatMessage::user(augmented);
    } else {
        // No user message found; this shouldn't happen in practice. We only
        // have a slice here, so we can't append a new message; log a warning
        // and leave the conversation untouched.
        warn!("No user message found for tool catalog injection");
    }
}

/// Check if `tool_choice` is explicitly "none"
fn is_tool_choice_none(tool_choice: Option<&ToolChoice>) -> bool {
    matches!(tool_choice, Some(ToolChoice::Mode(ref m)) if m == "none")
}

/// Generate a deterministic tool call ID from function name and index
fn generate_tool_call_id(name: &str, index: usize) -> String {
    format!("call_{name}_{index}")
}

/// Map a `RunnerError` to an appropriate HTTP status code and `OpenAI` error response
fn runner_error_to_response(err: &RunnerError) -> Response {
    let (status, error_type) = match err.kind {
        ErrorKind::BinaryNotFound => (StatusCode::SERVICE_UNAVAILABLE, "provider_not_available"),
        ErrorKind::AuthFailure => (StatusCode::UNAUTHORIZED, "authentication_error"),
        ErrorKind::Timeout => (StatusCode::GATEWAY_TIMEOUT, "timeout_error"),
        ErrorKind::ExternalService => (StatusCode::BAD_GATEWAY, "external_service_error"),
        ErrorKind::Config => (StatusCode::BAD_REQUEST, "invalid_request_error"),
        ErrorKind::Guardrail => (StatusCode::BAD_REQUEST, "guardrail_error"),
        ErrorKind::Internal => (StatusCode::INTERNAL_SERVER_ERROR, "server_error"),
    };

    error!(kind = ?err.kind, message = %err.message, "Runner error");
    let body = ErrorResponse::new(error_type, &err.message);
    (status, Json(body)).into_response()
}

/// Build an error response with a given status and message
fn error_response(status: StatusCode, message: &str) -> Response {
    let body = ErrorResponse::new("invalid_request_error", message);
    (status, Json(body)).into_response()
}

/// Monotonic counter ensuring unique IDs even for requests within the same second
static ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Generate a unique completion ID
///
/// Combines the unix timestamp with a monotonically increasing counter
/// to guarantee uniqueness across concurrent and rapid-fire requests.
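///
/// The resulting shape is `chatcmpl-<hex timestamp><8-hex-digit counter>`,
/// e.g. a hypothetical `chatcmpl-65a1b2c300000001`.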
pub fn generate_id() -> String {
    let ts = unix_timestamp();
    let seq = ID_COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("chatcmpl-{ts:x}{seq:08x}")
}

/// Get current unix timestamp in seconds
pub fn unix_timestamp() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::openai_types::{
        ContentPart, FunctionObject, ImageUrlDetail, ToolCall, ToolCallFunction, ToolDefinition,
    };
    use embacle::types::MessageRole;

    /// Helper to create a `ChatCompletionMessage` with plain text content
    fn text_msg(role: &str, content: Option<&str>) -> ChatCompletionMessage {
        ChatCompletionMessage {
            role: role.to_owned(),
            content: content.map(|c| MessageContent::Text(c.to_owned())),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }

    #[test]
    fn convert_messages_maps_roles() {
        let openai_msgs = vec![
            text_msg("system", Some("You are helpful")),
            text_msg("user", Some("Hello")),
            text_msg("assistant", Some("Hi there")),
        ];

        let messages = convert_messages(&openai_msgs);
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, MessageRole::System);
        assert_eq!(messages[1].role, MessageRole::User);
        assert_eq!(messages[2].role, MessageRole::Assistant);
    }

    #[test]
    fn convert_unknown_role_defaults_to_user() {
        let openai_msgs = vec![text_msg("function", Some("result"))];

        let messages = convert_messages(&openai_msgs);
        assert_eq!(messages[0].role, MessageRole::User);
    }

    #[test]
    fn convert_assistant_with_tool_calls() {
        let openai_msgs = vec![ChatCompletionMessage {
            role: "assistant".to_owned(),
            content: None,
            tool_calls: Some(vec![ToolCall {
                index: 0,
                id: "call_1".to_owned(),
                tool_type: "function".to_owned(),
                function: ToolCallFunction {
                    name: "get_weather".to_owned(),
                    arguments: r#"{"city":"Paris"}"#.to_owned(),
                },
            }]),
            tool_call_id: None,
            name: None,
        }];

        let messages = convert_messages(&openai_msgs);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, MessageRole::Assistant);
        assert!(messages[0].content.contains("<tool_call>"));
        assert!(messages[0].content.contains("get_weather"));
        assert!(messages[0].content.contains("</tool_call>"));
    }

    #[test]
    fn convert_tool_messages_to_user() {
        let openai_msgs = vec![
            ChatCompletionMessage {
                role: "tool".to_owned(),
                content: Some(MessageContent::Text(r#"{"temp":72}"#.to_owned())),
                tool_calls: None,
                tool_call_id: Some("call_1".to_owned()),
                name: Some("get_weather".to_owned()),
            },
            ChatCompletionMessage {
                role: "tool".to_owned(),
                content: Some(MessageContent::Text(r#"{"time":"14:30"}"#.to_owned())),
                tool_calls: None,
                tool_call_id: Some("call_2".to_owned()),
                name: Some("get_time".to_owned()),
            },
        ];

        let messages = convert_messages(&openai_msgs);
        // Consecutive tool messages should be merged into one user message
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, MessageRole::User);
        assert!(messages[0].content.contains("tool_result"));
        assert!(messages[0].content.contains("get_weather"));
        assert!(messages[0].content.contains("get_time"));
    }

    #[test]
    fn convert_messages_none_content() {
        let openai_msgs = vec![text_msg("user", None)];

        let messages = convert_messages(&openai_msgs);
        assert_eq!(messages[0].content, "");
    }

    #[test]
    fn convert_multipart_user_message_extracts_images() {
        let openai_msgs = vec![ChatCompletionMessage {
            role: "user".to_owned(),
            content: Some(MessageContent::Parts(vec![
                ContentPart::Text {
                    text: "What is this?".to_owned(),
                },
                ContentPart::ImageUrl {
                    image_url: ImageUrlDetail {
                        url: "data:image/png;base64,aGVsbG8=".to_owned(),
                    },
                },
            ])),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }];

        let messages = convert_messages(&openai_msgs);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content, "What is this?");
        let images = messages[0].images.as_ref().expect("images present");
        assert_eq!(images.len(), 1);
        assert_eq!(images[0].mime_type, "image/png");
        assert_eq!(images[0].data, "aGVsbG8=");
    }

    #[test]
    fn parse_data_uri_valid() {
        let img = parse_data_uri("data:image/jpeg;base64,AAAA").expect("should parse");
        assert_eq!(img.mime_type, "image/jpeg");
        assert_eq!(img.data, "AAAA");
    }

    #[test]
    fn parse_data_uri_invalid_format() {
        assert!(parse_data_uri("https://example.com/image.png").is_none());
        assert!(parse_data_uri("data:text/plain;base64,abc").is_none());
        assert!(parse_data_uri("data:image/png;abc").is_none());
    }

    #[test]
    fn convert_plain_string_content_backward_compat() {
        let openai_msgs = vec![text_msg("user", Some("hello"))];
        let messages = convert_messages(&openai_msgs);
        assert_eq!(messages[0].content, "hello");
        assert!(messages[0].images.is_none());
    }

    #[test]
    fn tools_to_declarations_converts() {
        let tools = vec![ToolDefinition {
            tool_type: "function".to_owned(),
            function: FunctionObject {
                name: "search".to_owned(),
                description: Some("Search the web".to_owned()),
                parameters: Some(serde_json::json!({
                    "type": "object",
                    "properties": {"q": {"type": "string"}},
                    "required": ["q"]
                })),
            },
        }];

        let decls = tools_to_declarations(&tools);
        assert_eq!(decls.len(), 1);
        assert_eq!(decls[0].name, "search");
        assert_eq!(decls[0].description, "Search the web");
        assert!(decls[0].parameters.is_some());
    }

    #[test]
    fn tool_choice_none_detection() {
        let none_choice = ToolChoice::Mode("none".to_owned());
        assert!(is_tool_choice_none(Some(&none_choice)));
        let auto_choice = ToolChoice::Mode("auto".to_owned());
        assert!(!is_tool_choice_none(Some(&auto_choice)));
        assert!(!is_tool_choice_none(None));
    }

    #[test]
    fn content_as_text_none() {
        assert_eq!(content_as_text(None), "");
    }

    #[test]
    fn content_as_text_plain() {
        let content = MessageContent::Text("hello".to_owned());
        assert_eq!(content_as_text(Some(&content)), "hello");
    }

    #[test]
    fn generate_tool_call_id_format() {
        let id = generate_tool_call_id("get_weather", 0);
        assert_eq!(id, "call_get_weather_0");
    }

    #[test]
    fn generate_id_has_prefix() {
        let id = generate_id();
        assert!(id.starts_with("chatcmpl-"));
    }

    #[test]
    fn error_maps_binary_not_found_to_503() {
        let err = RunnerError::binary_not_found("claude");
        let (status, _) = match err.kind {
            ErrorKind::BinaryNotFound => {
                (StatusCode::SERVICE_UNAVAILABLE, "provider_not_available")
            }
            _ => (StatusCode::INTERNAL_SERVER_ERROR, "server_error"),
        };
        assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
    }

    #[test]
    fn error_maps_auth_to_401() {
        let err = RunnerError::auth_failure("bad token");
        let (status, _) = match err.kind {
            ErrorKind::AuthFailure => (StatusCode::UNAUTHORIZED, "authentication_error"),
            _ => (StatusCode::INTERNAL_SERVER_ERROR, "server_error"),
        };
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn error_maps_timeout_to_504() {
        let err = RunnerError::timeout("too slow");
        let (status, _) = match err.kind {
            ErrorKind::Timeout => (StatusCode::GATEWAY_TIMEOUT, "timeout_error"),
            _ => (StatusCode::INTERNAL_SERVER_ERROR, "server_error"),
        };
        assert_eq!(status, StatusCode::GATEWAY_TIMEOUT);
    }

    #[test]
    fn inject_tool_catalog_as_user_message_prepends_to_last_user() {
        let mut messages = vec![
            ChatMessage::user("First question"),
            ChatMessage::assistant("Some answer"),
            ChatMessage::user("What is the weather?"),
        ];
        let catalog = "## Available Tools\n- get_weather: Get the weather";

        inject_tool_catalog_as_user_message(&mut messages, catalog);

        assert_eq!(messages.len(), 3);
        assert!(messages[2].content.starts_with("## Available Tools"));
        assert!(messages[2].content.contains("What is the weather?"));
        // First user message should be untouched
        assert_eq!(messages[0].content, "First question");
    }

    #[test]
    fn inject_tool_catalog_as_user_message_single_user() {
        let mut messages = vec![
            ChatMessage::system("You are helpful"),
            ChatMessage::user("Hello"),
        ];
        let catalog = "## Tools\nsome tools";

        inject_tool_catalog_as_user_message(&mut messages, catalog);

        assert!(messages[1].content.starts_with("## Tools"));
        assert!(messages[1].content.contains("Hello"));
    }

    #[test]
    fn wants_json_matches_json_formats() {
        use embacle::types::ResponseFormat;

        assert!(!wants_json(None));
        assert!(!wants_json(Some(&ResponseFormat::Text)));
        assert!(wants_json(Some(&ResponseFormat::JsonObject)));
        assert!(wants_json(Some(&ResponseFormat::JsonSchema {
            name: "test".to_owned(),
            schema: serde_json::json!({}),
        })));
    }

    #[test]
    fn strip_json_fences_removes_markdown_wrapper() {
        let fenced = "```json\n{\"key\":\"value\"}\n```".to_owned();
        assert_eq!(strip_json_fences(fenced, true), "{\"key\":\"value\"}");
    }

    #[test]
    fn strip_json_fences_passes_through_in_text_mode() {
        let fenced = "```json\n{\"key\":\"value\"}\n```".to_owned();
        assert_eq!(strip_json_fences(fenced.clone(), false), fenced);
    }

    #[test]
    fn strip_json_fences_leaves_clean_json_unchanged() {
        let clean = "{\"key\":\"value\"}".to_owned();
        assert_eq!(strip_json_fences(clean.clone(), true), clean);
    }
}