use markhor_core::chat::chat::{
ChatApi, ChatOptions, ContentPart, FinishReason,
Message, ToolChoice, ToolDefinition, ToolParameterSchema, ToolResult,
};
use markhor_core::embedding::{Embedder, EmbeddingUseCase, EmbeddingError};
use serde_json::json;
use tracing::error;
use markhor_extensions::gemini::{GeminiChatClient, GeminiEmbedder};
use tokio;
use tracing::info;
use tracing_subscriber; use std::env;
fn setup_tracing() {
let _ = tracing_subscriber::fmt::try_init();
}
async fn get_api_key() -> String {
println!("Current directory: {:?}", std::env::current_dir());
dotenv::dotenv().ok();
let api_key = env::var("GEMINI_API_KEY")
.expect("GEMINI_API_KEY environment variable not set.");
api_key
}
async fn get_chat_client() -> GeminiChatClient {
GeminiChatClient::new(get_api_key().await).expect("Failed to create GeminiChatClient")
}
#[tokio::test]
#[ignore]
async fn test_gemini_list_models_integration() {
setup_tracing();
let client = get_chat_client().await;
let result = client.list_models().await;
if let Err(ref e) = result {
error!(error=%e, "list_models failed");
}
assert!(result.is_ok());
let models = result.unwrap();
assert!(!models.is_empty());
assert!(models.iter().any(|m| m.id.contains("gemini")));
info!("Found {} models.", models.len());
for model in models.iter().take(5) { info!("Model: id={}, description={:?}, context_window={:?}, max_output={:?}",
model.id, model.description, model.context_window, model.max_output_tokens);
}
}
#[tokio::test]
#[ignore]
async fn test_gemini_generate_simple_integration() {
setup_tracing();
let client = get_chat_client().await;
let messages = vec![
Message::user("Hi there, what's the capital of France?"),
];
let options = ChatOptions {
model_id: Some("gemini-1.5-flash-latest".to_string()),
temperature: Some(0.7),
max_tokens: Some(50),
..Default::default()
};
let result = client.generate(&messages, &options).await;
if let Err(ref e) = result {
error!(error = %e, "generate failed");
}
assert!(result.is_ok());
let response = result.unwrap();
info!(response = ?response, "Received response");
assert!(!response.content.is_empty());
assert!(matches!(response.content[0], ContentPart::Text(_)));
let text_response = match &response.content[0] {
ContentPart::Text(t) => t,
_ => panic!("Expected text response"),
};
assert!(text_response.to_lowercase().contains("paris"));
assert!(response.finish_reason.is_some());
assert!(response.usage.is_some());
assert_eq!(response.model_id, Some("gemini-1.5-flash-latest".to_string()));
}
#[tokio::test]
#[ignore]
async fn test_gemini_generate_tool_use_integration() {
setup_tracing();
let client = get_chat_client().await;
let get_weather_tool = ToolDefinition {
name: "get_current_weather".to_string(),
description: "Get the current weather in a given location".to_string(),
parameters: ToolParameterSchema {
schema_type: "object".to_string(),
properties: serde_json::from_value(json!({
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
})).unwrap(),
required: vec!["location".to_string()],
},
};
let mut messages_1 = vec![
Message::user("What is the weather like in Boston? I prefer temperatures in Celsius."),
];
let options_1 = ChatOptions {
model_id: Some("gemini-1.5-flash-latest".to_string()), tools: Some(vec![get_weather_tool.clone()]),
tool_choice: Some(ToolChoice::Auto),
..Default::default()
};
info!("Sending initial request expecting tool call...");
let result_1 = client.generate(&messages_1, &options_1).await;
if let Err(ref e) = result_1 {
error!(error = %e, "generate (call 1) failed");
}
assert!(result_1.is_ok());
let response_1 = result_1.unwrap(); info!(response = ?response_1, "Received response 1 (expecting tool call)");
assert!(!response_1.tool_calls.is_empty());
let tool_call_request = response_1.tool_calls.first().unwrap();
assert_eq!(tool_call_request.name, "get_current_weather");
let assistant_message_from_response_1 = Message::assistant_response(
response_1.content.clone(), response_1.tool_calls.clone(), );
let tool_call_id = tool_call_request.id.clone(); let tool_arguments = &tool_call_request.arguments; assert!(tool_arguments.is_object());
assert!(tool_arguments["location"].to_string().contains("Boston"));
info!(call_id=%tool_call_id, args=?tool_arguments, "Simulating tool execution");
let tool_result_content = json!({
"temperature": "30",
"unit": "celsius",
"description": "Sunny"
});
let tool_result = ToolResult {
call_id: tool_call_request.id.clone(), name: tool_call_request.name.clone(), content: tool_result_content,
};
let tool_result_message = Message::Tool(vec![tool_result]);
let messages_2 = vec![
messages_1.pop().unwrap(), assistant_message_from_response_1, tool_result_message, ];
info!(history = ?messages_2, "Sending second request with tool result and full history...");
let options_2 = ChatOptions {
model_id: Some("gemini-1.5-flash-latest".to_string()),
tools: None, tool_choice: Some(ToolChoice::None), ..Default::default()
};
let result_2 = client.generate(&messages_2, &options_1).await;
if let Err(ref e) = result_2 {
error!(error = %e, "generate (call 2) failed");
}
assert!(result_2.is_ok(), "Second generate call failed");
let response_2 = result_2.unwrap();
info!(response = ?response_2, "Received response 2 (expecting final answer)");
assert!(response_2.finish_reason == Some(FinishReason::Stop) || response_2.finish_reason == Some(FinishReason::Length));
assert!(response_2.tool_calls.is_empty());
assert!(!response_2.content.is_empty());
assert!(matches!(response_2.content[0], ContentPart::Text(_)));
let text_response = match &response_2.content[0] {
ContentPart::Text(t) => t,
_ => panic!("Expected text response"),
};
assert!(text_response.contains("30") || text_response.to_lowercase().contains("sunny"));
assert!(text_response.to_lowercase().contains("boston")); info!("Final text response: {}", text_response);
}
#[tokio::test]
async fn integration_test_gemini_embed_success() {
let api_key = get_api_key().await;
let model_name = "embedding-001";
let embedder = match GeminiEmbedder::new(api_key, model_name) {
Ok(e) => e,
Err(e) => panic!("Failed to create GeminiEmbedder: {}", e),
};
let texts = ["Hello Gemini!", "This is an integration test."];
let result = embedder.embed(&texts).await;
match result {
Ok(embeddings) => {
assert_eq!(embeddings.len(), texts.len(), "Should return same number of embeddings as inputs");
let expected_dims = embedder.dimensions();
assert!(expected_dims.is_some(), "Dimensions should be known for {}", model_name);
let expected_dims = expected_dims.unwrap();
for (i, embedding) in embeddings.iter().enumerate() {
assert_eq!(embedding.0.len(), expected_dims, "Embedding {} should have correct dimensions", i);
assert!(embedding.0.iter().any(|&v| v != 0.0), "Embedding {} should not be all zeros", i);
}
assert_eq!(embedder.intended_use_case(), EmbeddingUseCase::General, "Default use case should be General");
}
Err(e) => {
panic!("Embedding failed unexpectedly: {}", e);
}
}
}
#[tokio::test]
#[ignore]
async fn integration_test_gemini_embed_with_task_type() {
let api_key = get_api_key().await;
let model_name = "embedding-001";
let task_type = "RETRIEVAL_DOCUMENT".to_string();
let embedder = match GeminiEmbedder::new_with_options(
api_key,
model_name,
Some(task_type.clone()),
None,
None,
) {
Ok(e) => e,
Err(e) => panic!("Failed to create GeminiEmbedder: {}", e),
};
assert_eq!(embedder.intended_use_case(), EmbeddingUseCase::RetrievalDocument, "Use case should be mapped correctly");
let texts = ["This document is intended for retrieval."];
let result = embedder.embed(&texts).await;
match result {
Ok(embeddings) => {
assert_eq!(embeddings.len(), texts.len(), "Should return embedding for the input");
let expected_dims = embedder.dimensions().unwrap_or(0);
assert_ne!(expected_dims, 0, "Dimensions should be > 0");
assert_eq!(embeddings[0].0.len(), expected_dims, "Embedding should have correct dimensions");
}
Err(e) => {
panic!("Embedding failed unexpectedly with task_type '{}': {}", task_type, e);
}
}
}
#[tokio::test]
#[ignore]
async fn integration_test_gemini_embed_invalid_api_key() {
let invalid_api_key = "THIS_IS_NOT_A_VALID_API_KEY";
let model_name = "embedding-001";
let embedder = match GeminiEmbedder::new(invalid_api_key, model_name) {
Ok(e) => e,
Err(e) => panic!("Failed to create GeminiEmbedder (should succeed): {}", e),
};
let texts = ["Testing with an invalid key."];
let result = embedder.embed(&texts).await;
assert!(result.is_err(), "Embedding should fail with invalid API key");
let err = result.err().unwrap();
assert!(
matches!(err, EmbeddingError::InvalidRequest { .. }),
"Error should be an ApiError variant, got {:?}", err
);
assert!(
format!("{}", err).contains("API key"),
"Error message should include 'API key', got {:?}", err
);
}
#[tokio::test]
#[ignore]
async fn integration_test_gemini_embed_invalid_model_name() {
let api_key = get_api_key().await;
let invalid_model_name = "non-existent-embedding-model-foobar";
let embedder = match GeminiEmbedder::new(api_key, invalid_model_name) {
Ok(e) => e,
Err(e) => panic!("Failed to create GeminiEmbedder: {}", e),
};
let texts = ["Testing with an invalid model."];
let result = embedder.embed(&texts).await;
assert!(result.is_err(), "Embedding should fail with invalid model name");
let err = result.err().unwrap();
assert!(
matches!(err, EmbeddingError::ModelNotFound( .. )),
"Error should be ModelNotFound variant, got {:?}", err
);
if let EmbeddingError::Api { status, message, .. } = err {
println!("Received ApiError status: {:?}, message: {}", status, message);
}
}
#[tokio::test]
#[ignore]
async fn integration_test_gemini_batch_limit_error() {
let api_key = get_api_key().await;
let model_name = "embedding-001";
let embedder = match GeminiEmbedder::new(api_key, model_name) {
Ok(e) => e,
Err(e) => panic!("Failed to create GeminiEmbedder: {}", e),
};
let large_batch: Vec<&str> = (0..101).map(|_| "text").collect();
let result = embedder.embed(&large_batch).await;
assert!(result.is_err(), "Embedding should fail when batch size exceeds limit");
let err = result.err().unwrap();
assert!(
matches!(err, EmbeddingError::BatchTooLarge { limit: Some(100), actual: 101 }),
"Error should be BatchTooLarge with correct details, got {:?}", err
);
}