lmkit 0.1.0

Multi-provider AI API client for OpenAI, Anthropic, Google Gemini, Aliyun, Ollama, and Zhipu, covering chat, embeddings (including Gemini), rerank, and image/audio stubs.
Documentation
use super::*;
use crate::chat::{ToolChoice, ToolDefinition};
use crate::config::Provider;
use futures::StreamExt;
use wiremock::matchers::{body_json, header, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};

/// Builds an OpenAI-provider config that points at the given mock server.
///
/// It uses the API key `test-key` and the model `gpt-4o-mini`. The request
/// matchers in the tests of this module expect exactly these values on the wire.
fn test_config(server: &MockServer) -> ProviderConfig {
    // `MockServer::uri` already returns an owned `String`; no extra copy is needed.
    ProviderConfig::new(Provider::OpenAI, "test-key", server.uri(), "gpt-4o-mini")
}

/// `chat` sends one user turn, with the bearer key and temperature 0.2,
/// and returns the first choice's assistant text.
#[tokio::test]
async fn chat_success_returns_assistant_content() {
    let mock_server = MockServer::start().await;

    let expected_request = serde_json::json!({
        "model": "gpt-4o-mini",
        "messages": [{ "role": "user", "content": "hello" }],
        "temperature": 0.2,
    });
    let canned_response = serde_json::json!({
        "choices": [{
            "message": { "role": "assistant", "content": "hi there" }
        }]
    });

    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .and(header("Authorization", "Bearer test-key"))
        .and(body_json(expected_request))
        .respond_with(ResponseTemplate::new(200).set_body_json(canned_response))
        .mount(&mock_server)
        .await;

    let client = OpenaiCompatChat::new(&test_config(&mock_server)).unwrap();
    let answer = client.chat("hello").await.unwrap();
    assert_eq!(answer, "hi there");
}

#[tokio::test]
async fn complete_multi_turn_and_system_in_body() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .and(body_json(serde_json::json!({
            "model": "gpt-4o-mini",
            "messages": [
                { "role": "system", "content": "You are helpful." },
                { "role": "user", "content": "hi" },
                { "role": "assistant", "content": "hello" },
                { "role": "user", "content": "bye" }
            ],
            "temperature": 0.2,
        })))
        .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "choices": [{
                "message": { "role": "assistant", "content": "see you" },
                "finish_reason": "stop"
            }]
        })))
        .mount(&server)
        .await;

    let chat = OpenaiCompatChat::new(&test_config(&server)).unwrap();
    let req = ChatRequest {
        messages: vec![
            ChatMessage::system("You are helpful."),
            ChatMessage::user("hi"),
            ChatMessage::assistant("hello"),
            ChatMessage::user("bye"),
        ],
        ..Default::default()
    };
    let r = chat.complete(&req).await.unwrap();
    assert_eq!(r.content.as_deref(), Some("see you"));
}

/// `complete` puts three things into the OpenAI-style request body:
///
/// - the tool definitions;
/// - a forced `ToolChoice::Tool`;
/// - the optional sampling fields (`temperature`, `max_tokens`, `top_p`).
#[tokio::test]
async fn complete_serializes_tool_choice_and_sampling() {
    let server = MockServer::start().await;
    let weather_params = serde_json::json!({ "type": "object", "properties": {} });

    let expected_body = serde_json::json!({
        "model": "gpt-4o-mini",
        "messages": [{ "role": "user", "content": "hi" }],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": weather_params.clone()
            }
        }],
        "tool_choice": {
            "type": "function",
            "function": { "name": "get_weather" }
        },
        "temperature": 0.7,
        "max_tokens": 512,
        "top_p": 0.9,
    });
    let reply = serde_json::json!({
        "choices": [{
            "message": { "role": "assistant", "content": "ok" }
        }]
    });

    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .and(body_json(expected_body))
        .respond_with(ResponseTemplate::new(200).set_body_json(reply))
        .mount(&server)
        .await;

    let client = OpenaiCompatChat::new(&test_config(&server)).unwrap();
    let request = ChatRequest {
        messages: vec![ChatMessage::user("hi")],
        temperature: Some(0.7),
        max_tokens: Some(512),
        top_p: Some(0.9),
        tools: Some(vec![ToolDefinition::function("get_weather", weather_params)]),
        tool_choice: Some(ToolChoice::Tool("get_weather".into())),
    };
    let response = client.complete(&request).await.unwrap();
    assert_eq!(response.content.as_deref(), Some("ok"));
}

/// A tool-call-only reply (`content: null`) is parsed as follows:
///
/// - `content` becomes `None`;
/// - the call's id, function name and raw argument string are exposed;
/// - the finish reason is `ToolCalls`.
#[tokio::test]
async fn complete_returns_tool_calls() {
    let server = MockServer::start().await;
    let tool_call_reply = serde_json::json!({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": null,
                "tool_calls": [{
                    "id": "call_1",
                    "type": "function",
                    "function": { "name": "get_weather", "arguments": "{\"city\":\"NYC\"}" }
                }]
            },
            "finish_reason": "tool_calls"
        }]
    });
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_json(tool_call_reply))
        .mount(&server)
        .await;

    let client = OpenaiCompatChat::new(&test_config(&server)).unwrap();
    let request = ChatRequest::single_user("weather?");
    let response = client.complete(&request).await.unwrap();

    assert!(response.content.is_none());
    assert_eq!(response.finish_reason, Some(FinishReason::ToolCalls));

    let calls = response.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 1);
    let call = &calls[0];
    assert_eq!(call.id, "call_1");
    assert_eq!(call.function.name, "get_weather");
    assert_eq!(call.function.arguments, "{\"city\":\"NYC\"}");
}

/// A tool call that lacks an `id` is dropped. Well-formed calls in the same
/// message are still returned.
#[tokio::test]
async fn complete_skips_malformed_tool_calls_keeps_valid() {
    let server = MockServer::start().await;

    // No "id" field, so this entry should be skipped by the parser.
    let missing_id = serde_json::json!(
        { "type": "function", "function": { "name": "bad", "arguments": "{}" } }
    );
    let well_formed = serde_json::json!({
        "id": "call_ok",
        "type": "function",
        "function": { "name": "good", "arguments": "{}" }
    });
    let reply = serde_json::json!({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": null,
                "tool_calls": [missing_id, well_formed]
            },
            "finish_reason": "tool_calls"
        }]
    });

    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_json(reply))
        .mount(&server)
        .await;

    let client = OpenaiCompatChat::new(&test_config(&server)).unwrap();
    let response = client.complete(&ChatRequest::single_user("x")).await.unwrap();
    let kept = response.tool_calls.as_ref().unwrap();
    assert_eq!(kept.len(), 1);
    let only = &kept[0];
    assert_eq!(only.id, "call_ok");
    assert_eq!(only.function.name, "good");
}

/// Streaming tool calls behave as follows:
///
/// - each `data:` event becomes one chunk carrying indexed tool-call deltas;
/// - the id and name arrive first, and the argument fragments arrive later;
/// - the last event reports `FinishReason::ToolCalls`;
/// - the `[DONE]` sentinel produces no chunk.
#[tokio::test]
async fn complete_stream_yields_tool_call_deltas() {
    let server = MockServer::start().await;
    let events = [
        "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_x\",\"function\":{\"name\":\"fn\"}}]},\"finish_reason\":null}]}\n\n",
        "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"a\\\":1}\"}}]},\"finish_reason\":null}]}\n\n",
        "data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n",
        "data: [DONE]\n\n",
    ];
    let sse_body = events.concat();

    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .and(body_json(serde_json::json!({
            "model": "gpt-4o-mini",
            "messages": [{ "role": "user", "content": "x" }],
            "temperature": 0.2,
            "stream": true,
        })))
        .respond_with(
            ResponseTemplate::new(200)
                .insert_header("content-type", "text/event-stream")
                .set_body_string(sse_body),
        )
        .mount(&server)
        .await;

    let client = OpenaiCompatChat::new(&test_config(&server)).unwrap();
    let stream = client
        .complete_stream(&ChatRequest::single_user("x"))
        .await
        .unwrap();
    let chunks = stream
        .map(|item| item.unwrap())
        .collect::<Vec<_>>()
        .await;

    assert_eq!(chunks.len(), 3);

    let opening = &chunks[0]
        .tool_call_deltas
        .as_ref()
        .expect("first chunk should carry tool-call deltas")[0];
    assert_eq!(opening.index, 0);
    assert_eq!(opening.id.as_deref(), Some("call_x"));
    assert_eq!(opening.function_name.as_deref(), Some("fn"));

    let args = &chunks[1].tool_call_deltas.as_ref().unwrap()[0];
    assert_eq!(args.function_arguments.as_deref(), Some("{\"a\":1}"));

    assert_eq!(chunks[2].finish_reason, Some(FinishReason::ToolCalls));
}

/// A `base_url` that ends in `/` must still reach `/chat/completions`.
/// The mock only answers that exact path, so a doubled slash would not match.
#[tokio::test]
async fn chat_base_url_trailing_slash_normalized() {
    let server = MockServer::start().await;
    let reply = serde_json::json!({
        "choices": [{
            "message": { "role": "assistant", "content": "ok" }
        }]
    });
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_json(reply))
        .mount(&server)
        .await;

    let config = {
        let mut with_slash = test_config(&server);
        with_slash.base_url = server.uri() + "/";
        with_slash
    };
    let client = OpenaiCompatChat::new(&config).unwrap();
    let answer = client.chat("x").await.unwrap();
    assert_eq!(answer, "ok");
}

#[tokio::test]
async fn chat_empty_choices_yields_missing_field() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "choices": []
        })))
        .mount(&server)
        .await;

    let chat = OpenaiCompatChat::new(&test_config(&server)).unwrap();
    let err = chat.chat("x").await.unwrap_err();
    match err {
        Error::MissingField(name) => assert_eq!(name, "choices[0]"),
        other => panic!("expected MissingField, got {:?}", other),
    }
}

#[tokio::test]
async fn chat_non_success_maps_to_api_error() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(ResponseTemplate::new(401).set_body_string("invalid key"))
        .mount(&server)
        .await;

    let chat = OpenaiCompatChat::new(&test_config(&server)).unwrap();
    let err = chat.chat("x").await.unwrap_err();
    match err {
        Error::Api { status, message } => {
            assert_eq!(status, 401);
            assert_eq!(message, "invalid key");
        }
        other => panic!("expected Api, got {:?}", other),
    }
}

#[tokio::test]
async fn chat_success_body_not_json_yields_parse() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_string("not json"))
        .mount(&server)
        .await;

    let chat = OpenaiCompatChat::new(&test_config(&server)).unwrap();
    let err = chat.chat("x").await.unwrap_err();
    match err {
        Error::Parse(_) => {}
        other => panic!("expected Parse, got {:?}", other),
    }
}

/// `chat_stream` turns each SSE `data:` event into one chunk: a content
/// delta first, then a chunk that has no delta but carries `FinishReason::Stop`.
/// The `[DONE]` sentinel produces no chunk.
#[tokio::test]
async fn chat_stream_yields_deltas_and_finish() {
    let server = MockServer::start().await;
    let events = [
        "data: {\"choices\":[{\"delta\":{\"content\":\"he\"},\"finish_reason\":null}]}\n\n",
        "data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n",
        "data: [DONE]\n\n",
    ];
    let sse_body = events.concat();
    let expected_request = serde_json::json!({
        "model": "gpt-4o-mini",
        "messages": [{ "role": "user", "content": "hello" }],
        "temperature": 0.2,
        "stream": true,
    });

    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .and(header("Authorization", "Bearer test-key"))
        .and(body_json(expected_request))
        .respond_with(
            ResponseTemplate::new(200)
                .insert_header("content-type", "text/event-stream")
                .set_body_string(sse_body),
        )
        .mount(&server)
        .await;

    let client = OpenaiCompatChat::new(&test_config(&server)).unwrap();
    let stream = client.chat_stream("hello").await.unwrap();
    let chunks = stream
        .map(|item| item.unwrap())
        .collect::<Vec<_>>()
        .await;

    assert_eq!(chunks.len(), 2);
    let (first, last) = (&chunks[0], &chunks[1]);
    assert_eq!(first.delta.as_deref(), Some("he"));
    assert!(first.finish_reason.is_none());
    assert_eq!(last.delta, None);
    assert_eq!(last.finish_reason, Some(FinishReason::Stop));
}

#[tokio::test]
async fn chat_stream_http_error_before_body() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .respond_with(ResponseTemplate::new(429).set_body_string("rate"))
        .mount(&server)
        .await;

    let chat = OpenaiCompatChat::new(&test_config(&server)).unwrap();
    let err = match chat.chat_stream("x").await {
        Err(e) => e,
        Ok(_) => panic!("expected error"),
    };
    match err {
        Error::Api { status, message } => {
            assert_eq!(status, 429);
            assert_eq!(message, "rate");
        }
        other => panic!("expected Api, got {:?}", other),
    }
}