lmkit 0.1.0

Multi-provider AI API client (OpenAI, Anthropic, Google Gemini, Aliyun, Ollama, Zhipu; chat, embed incl. Gemini, rerank, image, audio stubs)
Documentation
use super::*;
use crate::chat::{ChatChunk, ToolChoice, ToolDefinition};
use crate::config::Provider;
use crate::error::Error;
use futures::StreamExt;
use wiremock::matchers::{method, path_regex, query_param};
use wiremock::{Mock, MockServer, ResponseTemplate};

fn test_config(server: &MockServer) -> ProviderConfig {
    ProviderConfig::new(
        Provider::Google,
        "AIza-test",
        format!("{}/v1beta", server.uri()),
        "gemini-2.0-flash",
    )
}

#[tokio::test]
async fn generate_content_success_returns_text() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path_regex(r"/v1beta/models/[^/]+(:|%3A)generateContent"))
        .and(query_param("key", "AIza-test"))
        .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{ "text": "hi there" }],
                    "role": "model"
                }
            }]
        })))
        .mount(&server)
        .await;

    let chat = GoogleGeminiChat::new(&test_config(&server)).unwrap();
    let reply = chat.chat("hello").await.unwrap();
    assert_eq!(reply, "hi there");
}

#[tokio::test]
async fn complete_multi_turn_and_system_instruction() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path_regex(r"/v1beta/models/[^/]+(:|%3A)generateContent"))
        .and(query_param("key", "AIza-test"))
        .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "candidates": [{
                "content": { "parts": [{ "text": "ok" }] }
            }]
        })))
        .mount(&server)
        .await;

    let chat = GoogleGeminiChat::new(&test_config(&server)).unwrap();
    let req = ChatRequest {
        messages: vec![
            ChatMessage::system("Be brief."),
            ChatMessage::user("hi"),
            ChatMessage::assistant("hey"),
            ChatMessage::user("bye"),
        ],
        ..Default::default()
    };
    let r = chat.complete(&req).await.unwrap();
    assert_eq!(r.content.as_deref(), Some("ok"));
}

#[tokio::test]
async fn complete_errors_when_tool_message_missing_function_name() {
    let server = MockServer::start().await;
    let chat = GoogleGeminiChat::new(&test_config(&server)).unwrap();
    let req = ChatRequest {
        messages: vec![ChatMessage::user("hi"), ChatMessage::tool("call_1", "{}")],
        ..Default::default()
    };
    let err = chat.complete(&req).await.unwrap_err();
    match err {
        Error::MissingField(name) => assert_eq!(name, "tool.name"),
        other => panic!("expected MissingField(tool.name), got {:?}", other),
    }
}

#[tokio::test]
async fn complete_with_tools_tool_choice_and_sampling_succeeds() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path_regex(r"/v1beta/models/[^/]+(:|%3A)generateContent"))
        .and(query_param("key", "AIza-test"))
        .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "candidates": [{
                "content": { "parts": [{ "text": "ok" }] }
            }]
        })))
        .mount(&server)
        .await;

    let chat = GoogleGeminiChat::new(&test_config(&server)).unwrap();
    let req = ChatRequest {
        messages: vec![ChatMessage::user("hi")],
        tools: Some(vec![ToolDefinition::function(
            "get_weather",
            serde_json::json!({ "type": "object", "properties": {} }),
        )]),
        tool_choice: Some(ToolChoice::Tool("get_weather".into())),
        temperature: Some(0.7),
        max_tokens: Some(512),
        top_p: Some(0.85),
    };
    let r = chat.complete(&req).await.unwrap();
    assert_eq!(r.content.as_deref(), Some("ok"));
}

#[tokio::test]
async fn complete_parses_function_call() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path_regex(r"/v1beta/models/[^/]+(:|%3A)generateContent"))
        .and(query_param("key", "AIza-test"))
        .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{
                        "functionCall": {
                            "name": "get_weather",
                            "args": { "city": "NYC" }
                        }
                    }]
                },
                "finishReason": "STOP"
            }]
        })))
        .mount(&server)
        .await;

    let chat = GoogleGeminiChat::new(&test_config(&server)).unwrap();
    let r = chat.complete(&ChatRequest::single_user("w")).await.unwrap();
    let tc = r.tool_calls.as_ref().unwrap();
    assert_eq!(tc[0].function.name, "get_weather");
    assert!(tc[0].function.arguments.contains("NYC"));
    assert_eq!(r.finish_reason, Some(FinishReason::ToolCalls));
}

#[tokio::test]
async fn generate_content_base_url_trailing_slash_normalized() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path_regex(r"/v1beta/models/[^/]+(:|%3A)generateContent"))
        .and(query_param("key", "AIza-test"))
        .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{ "text": "ok" }]
                }
            }]
        })))
        .mount(&server)
        .await;

    let mut cfg = test_config(&server);
    cfg.base_url = format!("{}/v1beta/", server.uri());
    let chat = GoogleGeminiChat::new(&cfg).unwrap();
    assert_eq!(chat.chat("x").await.unwrap(), "ok");
}

#[tokio::test]
async fn generate_content_empty_text_parts_yields_missing_field() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path_regex(r"/v1beta/models/[^/]+(:|%3A)generateContent"))
        .and(query_param("key", "AIza-test"))
        .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{ "text": "" }]
                }
            }]
        })))
        .mount(&server)
        .await;

    let chat = GoogleGeminiChat::new(&test_config(&server)).unwrap();
    let err = chat.chat("x").await.unwrap_err();
    match err {
        Error::MissingField(name) => {
            assert_eq!(name, "response content");
        }
        other => panic!("expected MissingField, got {:?}", other),
    }
}

#[tokio::test]
async fn generate_content_empty_candidates_includes_prompt_feedback() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path_regex(r"/v1beta/models/[^/]+(:|%3A)generateContent"))
        .and(query_param("key", "AIza-test"))
        .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "candidates": [],
            "promptFeedback": { "blockReason": "BLOCK_REASON_UNSPECIFIED" }
        })))
        .mount(&server)
        .await;

    let chat = GoogleGeminiChat::new(&test_config(&server)).unwrap();
    let err = chat.chat("x").await.unwrap_err();
    match err {
        Error::Parse(msg) => {
            assert!(
                msg.contains("no candidates") && msg.contains("blockReason"),
                "unexpected message: {msg}"
            );
        }
        other => panic!("expected Parse, got {:?}", other),
    }
}

#[tokio::test]
async fn generate_content_non_success_maps_to_api_error() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path_regex(r"/v1beta/models/[^/]+(:|%3A)generateContent"))
        .respond_with(ResponseTemplate::new(403).set_body_string("forbidden"))
        .mount(&server)
        .await;

    let chat = GoogleGeminiChat::new(&test_config(&server)).unwrap();
    let err = chat.chat("x").await.unwrap_err();
    match err {
        Error::Api { status, message } => {
            assert_eq!(status, 403);
            assert_eq!(message, "forbidden");
        }
        other => panic!("expected Api, got {:?}", other),
    }
}

#[tokio::test]
async fn stream_generate_content_yields_text_chunk() {
    let server = MockServer::start().await;
    let sse_body = "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Hi\"}]}}]}\n\n";
    Mock::given(method("POST"))
        .and(path_regex(
            r"/v1beta/models/[^/]+(:|%3A)streamGenerateContent",
        ))
        .and(query_param("key", "AIza-test"))
        .respond_with(
            ResponseTemplate::new(200)
                .insert_header("content-type", "text/event-stream")
                .set_body_string(sse_body),
        )
        .mount(&server)
        .await;

    let chat = GoogleGeminiChat::new(&test_config(&server)).unwrap();
    let mut stream = chat.chat_stream("hello").await.unwrap();
    let chunk = stream.next().await.unwrap().unwrap();
    assert_eq!(chunk, ChatChunk::delta("Hi"));
    assert!(stream.next().await.is_none());
}

#[tokio::test]
async fn stream_maps_finish_reason_tool_calls_when_function_call_with_stop() {
    let server = MockServer::start().await;
    let payload = serde_json::json!({
        "candidates": [{
            "content": { "parts": [{ "functionCall": { "name": "fn", "args": { "a": 1 } } }] },
            "finishReason": "STOP"
        }]
    });
    let sse_body = format!("data: {payload}\n\n");
    Mock::given(method("POST"))
        .and(path_regex(
            r"/v1beta/models/[^/]+(:|%3A)streamGenerateContent",
        ))
        .and(query_param("key", "AIza-test"))
        .respond_with(
            ResponseTemplate::new(200)
                .insert_header("content-type", "text/event-stream")
                .set_body_string(sse_body),
        )
        .mount(&server)
        .await;

    let chat = GoogleGeminiChat::new(&test_config(&server)).unwrap();
    let mut stream = chat
        .complete_stream(&ChatRequest::single_user("x"))
        .await
        .unwrap();
    let chunk = stream.next().await.unwrap().unwrap();
    assert!(chunk.tool_call_deltas.is_some());
    assert_eq!(chunk.finish_reason, Some(FinishReason::ToolCalls));
}

#[tokio::test]
async fn stream_yields_function_call_delta() {
    let server = MockServer::start().await;
    let payload = serde_json::json!({"candidates":[{"content":{"parts":[{"functionCall":{"name":"fn","args":{"a":1}}}]}}]});
    let sse_body = format!("data: {payload}\n\n");
    Mock::given(method("POST"))
        .and(path_regex(
            r"/v1beta/models/[^/]+(:|%3A)streamGenerateContent",
        ))
        .and(query_param("key", "AIza-test"))
        .respond_with(
            ResponseTemplate::new(200)
                .insert_header("content-type", "text/event-stream")
                .set_body_string(sse_body),
        )
        .mount(&server)
        .await;

    let chat = GoogleGeminiChat::new(&test_config(&server)).unwrap();
    let mut stream = chat
        .complete_stream(&ChatRequest::single_user("x"))
        .await
        .unwrap();
    let chunk = stream.next().await.unwrap().unwrap();
    assert!(chunk.tool_call_deltas.is_some());
    let d = &chunk.tool_call_deltas.as_ref().unwrap()[0];
    assert_eq!(d.function_name.as_deref(), Some("fn"));
    assert!(d.function_arguments.as_ref().unwrap().contains("1"));
}