#![cfg(all(feature = "api", feature = "embeddings"))]
use axum::{
body::Body,
http::{Request, StatusCode},
};
use serde_json::json;
use tower::ServiceExt;
use kreuzberg::{
ExtractionConfig,
api::{EmbedResponse, create_router},
};
/// Embeds two short texts with the default model and checks that the response
/// contains one vector per input, a non-empty model name, and that every vector
/// has exactly the advertised `dimensions`.
///
/// This test performs a real model load (it expects HTTP 200), so it carries the
/// same aarch64 ignore guard as the other model-loading tests in this file.
#[tokio::test]
#[cfg_attr(target_arch = "aarch64", ignore = "ONNX Runtime model loading unstable on ARM")]
async fn test_embed_valid_texts() {
    let app = create_router(ExtractionConfig::default());
    let request_body = json!({
        "texts": ["Hello world", "Second text"]
    });
    let response = app
        .oneshot(
            Request::builder()
                .method("POST")
                .uri("/embed")
                .header("content-type", "application/json")
                .body(Body::from(
                    serde_json::to_string(&request_body).expect("Operation failed"),
                ))
                .expect("Operation failed"),
        )
        .await
        .expect("Operation failed");
    assert_eq!(response.status(), StatusCode::OK);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .expect("Failed to convert to bytes");
    let embed_response: EmbedResponse = serde_json::from_slice(&body).expect("Failed to deserialize");
    assert_eq!(embed_response.count, 2);
    assert_eq!(embed_response.embeddings.len(), 2);
    assert!(embed_response.dimensions > 0);
    assert!(!embed_response.model.is_empty());
    // Each vector must match the dimension the server reports.
    for embedding in &embed_response.embeddings {
        assert_eq!(embedding.len(), embed_response.dimensions);
    }
}
/// Posting an empty `texts` array is a client error: the endpoint must reply
/// with HTTP 400 rather than producing an empty embedding set.
#[tokio::test]
async fn test_embed_empty_texts() {
    let payload = serde_json::to_string(&json!({ "texts": [] })).expect("Operation failed");
    let request = Request::builder()
        .method("POST")
        .uri("/embed")
        .header("content-type", "application/json")
        .body(Body::from(payload))
        .expect("Operation failed");

    let response = create_router(ExtractionConfig::default())
        .oneshot(request)
        .await
        .expect("Operation failed");

    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
/// Sends a per-request embedding config that selects the `fast` preset with a
/// custom batch size, and checks that the response echoes that preset name and
/// carries exactly one vector for the single input text.
#[tokio::test]
#[cfg_attr(target_arch = "aarch64", ignore = "ONNX Runtime model loading unstable on ARM")]
async fn test_embed_with_custom_config() {
    let app = create_router(ExtractionConfig::default());

    let embedding_config = json!({
        "model": { "type": "preset", "name": "fast" },
        "batch_size": 32
    });
    let payload = json!({
        "texts": ["Test embedding with custom config"],
        "config": embedding_config
    });

    let request = Request::builder()
        .method("POST")
        .uri("/embed")
        .header("content-type", "application/json")
        .body(Body::from(
            serde_json::to_string(&payload).expect("Operation failed"),
        ))
        .expect("Operation failed");
    let response = app.oneshot(request).await.expect("Operation failed");
    assert_eq!(response.status(), StatusCode::OK);

    let raw = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .expect("Failed to convert to bytes");
    let parsed: EmbedResponse = serde_json::from_slice(&raw).expect("Failed to deserialize");

    assert_eq!(parsed.count, 1);
    assert_eq!(parsed.embeddings.len(), 1);
    assert_eq!(parsed.model, "fast");
}
/// A request with exactly one text must yield exactly one embedding vector.
#[tokio::test]
#[cfg_attr(target_arch = "aarch64", ignore = "ONNX Runtime model loading unstable on ARM")]
async fn test_embed_single_text() {
    let router = create_router(ExtractionConfig::default());
    let serialized = serde_json::to_string(&json!({
        "texts": ["Single text for embedding"]
    }))
    .expect("Operation failed");

    let request = Request::builder()
        .method("POST")
        .uri("/embed")
        .header("content-type", "application/json")
        .body(Body::from(serialized))
        .expect("Operation failed");
    let response = router.oneshot(request).await.expect("Operation failed");
    assert_eq!(response.status(), StatusCode::OK);

    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .expect("Failed to convert to bytes");
    let result: EmbedResponse = serde_json::from_slice(&bytes).expect("Failed to deserialize");

    assert_eq!(result.count, 1);
    assert_eq!(result.embeddings.len(), 1);
}
/// Embeds a batch of ten distinct texts and checks that every input gets a
/// vector, that the advertised `dimensions` is positive, and that every vector
/// has exactly that length.
///
/// This test performs a real model load (it expects HTTP 200), so it carries the
/// same aarch64 ignore guard as the other model-loading tests in this file.
#[tokio::test]
#[cfg_attr(target_arch = "aarch64", ignore = "ONNX Runtime model loading unstable on ARM")]
async fn test_embed_batch() {
    let app = create_router(ExtractionConfig::default());
    let texts: Vec<String> = (0..10).map(|i| format!("Test text number {}", i)).collect();
    let request_body = json!({
        "texts": texts
    });
    let response = app
        .oneshot(
            Request::builder()
                .method("POST")
                .uri("/embed")
                .header("content-type", "application/json")
                .body(Body::from(
                    serde_json::to_string(&request_body).expect("Operation failed"),
                ))
                .expect("Operation failed"),
        )
        .await
        .expect("Operation failed");
    assert_eq!(response.status(), StatusCode::OK);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .expect("Failed to convert to bytes");
    let embed_response: EmbedResponse = serde_json::from_slice(&body).expect("Failed to deserialize");
    assert_eq!(embed_response.count, 10);
    assert_eq!(embed_response.embeddings.len(), 10);
    // Checking against the advertised dimension (rather than only against the
    // first vector) also catches a response whose `dimensions` field disagrees
    // with the vectors it carries.
    assert!(embed_response.dimensions > 0);
    for embedding in &embed_response.embeddings {
        assert_eq!(embedding.len(), embed_response.dimensions);
    }
}
/// A long input (the same sentence repeated 100 times) must still be accepted
/// and produce a single embedding vector.
#[tokio::test]
#[cfg_attr(target_arch = "aarch64", ignore = "ONNX Runtime model loading unstable on ARM")]
async fn test_embed_long_text() {
    let router = create_router(ExtractionConfig::default());
    let sentence = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. ";
    let payload = json!({ "texts": [sentence.repeat(100)] });

    let request = Request::builder()
        .method("POST")
        .uri("/embed")
        .header("content-type", "application/json")
        .body(Body::from(
            serde_json::to_string(&payload).expect("Operation failed"),
        ))
        .expect("Operation failed");
    let response = router.oneshot(request).await.expect("Operation failed");
    assert_eq!(response.status(), StatusCode::OK);

    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .expect("Failed to convert to bytes");
    let decoded: EmbedResponse = serde_json::from_slice(&bytes).expect("Failed to deserialize");

    assert_eq!(decoded.count, 1);
    assert_eq!(decoded.embeddings.len(), 1);
}
/// A body that is not valid JSON must be rejected with HTTP 400.
#[tokio::test]
async fn test_embed_malformed_json() {
    let request = Request::builder()
        .method("POST")
        .uri("/embed")
        .header("content-type", "application/json")
        .body(Body::from("{invalid json}"))
        .expect("Operation failed");

    let status = create_router(ExtractionConfig::default())
        .oneshot(request)
        .await
        .expect("Operation failed")
        .status();

    assert_eq!(status, StatusCode::BAD_REQUEST);
}
/// A top-level JSON array (here a mixed array containing a nested array and an
/// object) is not a valid request; either 400 or 422 is accepted.
#[tokio::test]
async fn test_embed_rejects_json_array() {
    let request = Request::builder()
        .method("POST")
        .uri("/embed")
        .header("content-type", "application/json")
        .body(Body::from(r#"[["text1"], {"texts": ["text2"]}]"#))
        .expect("Operation failed");
    let response = create_router(ExtractionConfig::default())
        .oneshot(request)
        .await
        .expect("Operation failed");

    let status = response.status();
    let accepted = [StatusCode::BAD_REQUEST, StatusCode::UNPROCESSABLE_ENTITY];
    assert!(accepted.contains(&status), "Expected 400 or 422, got {}", status);
}
/// A bare JSON array of strings (instead of `{"texts": [...]}`) must be rejected
/// with HTTP 400, and the JSON error body's `message` field must mention
/// `array` or `object` so the caller can see what shape was expected.
#[tokio::test]
async fn test_embed_rejects_simple_json_array() {
    let app = create_router(ExtractionConfig::default());
    let response = app
        .oneshot(
            Request::builder()
                .method("POST")
                .uri("/embed")
                .header("content-type", "application/json")
                .body(Body::from(r#"["text1", "text2", "text3"]"#))
                .expect("Operation failed"),
        )
        .await
        .expect("Operation failed");
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .expect("Failed to read response body");
    let error_response: serde_json::Value = serde_json::from_slice(&body).expect("Failed to parse error response");
    // A missing or non-string `message` becomes "", which fails the check below,
    // as before. The assertion message now includes the payload, so a regression
    // shows what the server actually returned.
    let message = error_response["message"].as_str().unwrap_or_default();
    assert!(
        message.contains("array") || message.contains("object"),
        "expected error message mentioning 'array' or 'object', got: {error_response}"
    );
}
/// Embeds the same text twice through the same router and checks that both
/// runs return the same number of vectors and an identical first vector.
#[tokio::test]
#[cfg_attr(target_arch = "aarch64", ignore = "ONNX Runtime model loading unstable on ARM")]
async fn test_embed_deterministic() {
    let app = create_router(ExtractionConfig::default());
    let request_body = json!({
        "texts": ["Deterministic test"]
    });
    let payload = serde_json::to_string(&request_body).expect("Operation failed");

    // Issue the identical request twice; each round trip goes through a clone
    // of the router because `oneshot` consumes the service it is called on.
    let mut runs: Vec<EmbedResponse> = Vec::with_capacity(2);
    for _ in 0..2 {
        let request = Request::builder()
            .method("POST")
            .uri("/embed")
            .header("content-type", "application/json")
            .body(Body::from(payload.clone()))
            .expect("Operation failed");
        let response = app.clone().oneshot(request).await.expect("Operation failed");
        assert_eq!(response.status(), StatusCode::OK);
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("Failed to convert to bytes");
        runs.push(serde_json::from_slice(&bytes).expect("Failed to deserialize"));
    }

    let (first, second) = (&runs[0], &runs[1]);
    assert_eq!(first.embeddings.len(), second.embeddings.len());
    assert_eq!(first.embeddings[0], second.embeddings[0]);
}
/// Requests embeddings with the `fast` and then the `balanced` preset and checks
/// that each response echoes its preset name and that the two presets report
/// different vector dimensions.
#[tokio::test]
#[cfg_attr(target_arch = "aarch64", ignore = "ONNX Runtime model loading unstable on ARM")]
async fn test_embed_different_presets() {
    let app = create_router(ExtractionConfig::default());

    // Collected in request order: index 0 is `fast`, index 1 is `balanced`.
    let mut results: Vec<EmbedResponse> = Vec::with_capacity(2);
    for preset in ["fast", "balanced"] {
        let request_body = json!({
            "texts": ["Test text"],
            "config": {
                "model": {
                    "type": "preset",
                    "name": preset
                }
            }
        });
        let request = Request::builder()
            .method("POST")
            .uri("/embed")
            .header("content-type", "application/json")
            .body(Body::from(
                serde_json::to_string(&request_body).expect("Operation failed"),
            ))
            .expect("Operation failed");
        let response = app.clone().oneshot(request).await.expect("Operation failed");
        assert_eq!(response.status(), StatusCode::OK);
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("Operation failed");
        results.push(serde_json::from_slice(&bytes).expect("Failed to deserialize"));
    }

    let (embed_fast, embed_balanced) = (&results[0], &results[1]);
    assert_ne!(embed_fast.dimensions, embed_balanced.dimensions);
    assert_eq!(embed_fast.model, "fast");
    assert_eq!(embed_balanced.model, "balanced");
}