use std::{
collections::HashMap,
convert::Infallible,
net::SocketAddr,
sync::{Arc, Mutex},
};
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::{
body::Incoming,
http::{Request, Response, StatusCode},
service::service_fn,
};
use hyper_util::{rt::TokioIo, server::conn::auto::Builder as ConnBuilder};
use serde_json::json;
use tokio::net::TcpListener;
/// Settings shared between the mock server and the client code exercised against it.
#[derive(Debug, Clone)]
pub struct MockServerConfig {
    /// Key the server expects in the `Authorization: Bearer <api_key>` request header.
    pub api_key: String,
    /// Base URL that clients should use to reach the mock server.
    pub base_url: String,
}

impl Default for MockServerConfig {
    /// A fixed test key and the loopback address `http://127.0.0.1:9876`.
    fn default() -> Self {
        let api_key = String::from("test.12345678901234567890");
        let base_url = String::from("http://127.0.0.1:9876");
        MockServerConfig { api_key, base_url }
    }
}
/// Server-side state: the configuration plus a table of canned JSON responses.
///
/// Cloning is cheap, and every clone shares the same response table through the inner `Arc`.
#[derive(Debug, Clone)]
pub struct MockServerState {
    /// Configuration the state was created with; the request handler reads `api_key` from it.
    _config: MockServerConfig,
    /// Canned response bodies keyed by endpoint path, shared across all clones.
    responses: Arc<Mutex<HashMap<String, serde_json::Value>>>,
}

impl MockServerState {
    /// Creates state for `config` with an empty response table.
    pub fn new(config: MockServerConfig) -> Self {
        let responses = Arc::new(Mutex::new(HashMap::new()));
        MockServerState {
            _config: config,
            responses,
        }
    }

    /// Stores `response` as the body to return for `endpoint`, replacing any earlier entry.
    ///
    /// # Panics
    /// Panics if the response table's mutex is poisoned.
    pub fn register_response(&self, endpoint: &str, response: serde_json::Value) {
        self.responses
            .lock()
            .unwrap()
            .insert(endpoint.to_owned(), response);
    }

    /// Returns a copy of the body registered for `endpoint`, if any.
    ///
    /// # Panics
    /// Panics if the response table's mutex is poisoned.
    pub fn get_response(&self, endpoint: &str) -> Option<serde_json::Value> {
        self.responses.lock().unwrap().get(endpoint).cloned()
    }
}
#[allow(dead_code)]
pub async fn start_mock_server(config: MockServerConfig) -> Result<(), Box<dyn std::error::Error>> {
let state = MockServerState::new(config.clone());
let addr: SocketAddr = ([127, 0, 0, 1], 9876).into();
let listener = TcpListener::bind(addr).await?;
println!("Mock server running on http://127.0.0.1:9876");
loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);
let state = state.clone();
tokio::task::spawn(async move {
let service = service_fn(move |req| {
let state = state.clone();
async move { handle_request(req, state).await }
});
if let Err(err) = ConnBuilder::new(hyper_util::rt::TokioExecutor::new())
.serve_connection(io, service)
.await
{
eprintln!("Error serving connection: {:?}", err);
}
});
}
}
/// Authenticates one request and routes it to the matching mock endpoint handler.
///
/// The request must carry `Authorization: Bearer <api_key>`, using the key from the state's
/// configuration. Otherwise a 401 with application error code 1001 is returned before
/// routing, so unknown paths are rejected too when unauthenticated. Routes:
/// - `POST /api/paas/v4/chat/completions`
/// - `POST /api/paas/v4/embeddings`
/// - `GET /api/paas/v4/files/<id>`
///
/// Any other method/path pair gets a 404 with error code 0.
#[allow(dead_code)]
async fn handle_request(
    req: Request<Incoming>,
    state: MockServerState,
) -> Result<Response<Full<Bytes>>, Infallible> {
    let path = req.uri().path();
    let method = req.method().as_str();
    let expected_auth = format!("Bearer {}", state._config.api_key);
    // Compare raw header bytes: a header that is present but not visible ASCII must be
    // reported as an invalid key, not as a missing header (which `to_str().ok()` would do).
    match req.headers().get(hyper::header::AUTHORIZATION) {
        None => {
            return Ok(create_error_response(
                StatusCode::UNAUTHORIZED,
                1001,
                "Missing authorization header",
            ));
        }
        Some(auth) if auth.as_bytes() != expected_auth.as_bytes() => {
            return Ok(create_error_response(
                StatusCode::UNAUTHORIZED,
                1001,
                "Invalid API key",
            ));
        }
        Some(_) => {}
    }
    match (method, path) {
        ("POST", "/api/paas/v4/chat/completions") => handle_chat_completion(req, &state).await,
        ("POST", "/api/paas/v4/embeddings") => handle_embeddings(req, &state).await,
        ("GET", _) if path.starts_with("/api/paas/v4/files/") => {
            handle_file_retrieval(path, &state).await
        }
        _ => Ok(create_error_response(
            StatusCode::NOT_FOUND,
            0,
            "Endpoint not found",
        )),
    }
}
#[allow(dead_code)]
async fn handle_chat_completion(
req: Request<Incoming>,
state: &MockServerState,
) -> Result<Response<Full<Bytes>>, Infallible> {
let body = req.collect().await.unwrap().to_bytes();
let _request_body: serde_json::Value = serde_json::from_slice(&body).unwrap_or(json!({}));
let response_body =
if let Some(custom_response) = state.get_response("/api/paas/v4/chat/completions") {
custom_response
} else {
json!({
"id": "chatcmpl-1234567890",
"object": "chat.completion",
"created": 1704067200,
"model": "glm-4",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "This is a mock response from the integration test server."
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
})
};
Ok(Response::new(Full::new(Bytes::from(
serde_json::to_string(&response_body).unwrap(),
))))
}
#[allow(dead_code)]
async fn handle_embeddings(
req: Request<Incoming>,
state: &MockServerState,
) -> Result<Response<Full<Bytes>>, Infallible> {
let _body = req.collect().await.unwrap().to_bytes();
let response_body = if let Some(custom_response) = state.get_response("/api/paas/v4/embeddings")
{
custom_response
} else {
json!({
"object": "list",
"data": [{
"object": "embedding",
"embedding": [0.002, -0.002, 0.004, 0.001, -0.003, 0.002],
"index": 0
}],
"model": "embedding-2",
"usage": {
"prompt_tokens": 8,
"total_tokens": 8
}
})
};
Ok(Response::new(Full::new(Bytes::from(
serde_json::to_string(&response_body).unwrap(),
))))
}
/// Serves `GET /api/paas/v4/files/<id>`.
///
/// A response registered under the exact request `path` takes precedence. Otherwise a
/// fixed metadata object for a 1024-byte `test.txt` is returned with status 200.
#[allow(dead_code)]
async fn handle_file_retrieval(
    path: &str,
    state: &MockServerState,
) -> Result<Response<Full<Bytes>>, Infallible> {
    let file_info = state.get_response(path).unwrap_or_else(|| {
        json!({
            "id": "file-1234567890",
            "object": "file",
            "bytes": 1024,
            "created_at": 1704067200,
            "filename": "test.txt",
            "purpose": "assistants"
        })
    });
    // Compact JSON, identical to `serde_json::to_string(&file_info)`.
    let serialized = file_info.to_string();
    Ok(Response::new(Full::new(Bytes::from(serialized))))
}
#[allow(dead_code)]
fn create_error_response(status: StatusCode, code: u16, message: &str) -> Response<Full<Bytes>> {
let error_body = json!({
"error": {
"code": code,
"message": message
}
});
let mut response = Response::new(Full::new(Bytes::from(
serde_json::to_string(&error_body).unwrap(),
)));
*response.status_mut() = status;
response
}
/// Small helper that turns endpoint paths into absolute URLs on a mock server.
#[derive(Debug, Clone)]
pub struct MockServerClient {
    /// Prefix prepended verbatim to every endpoint (no slash normalization).
    base_url: String,
}

impl MockServerClient {
    /// Creates a client rooted at `base_url`.
    pub fn new(base_url: String) -> Self {
        MockServerClient { base_url }
    }

    /// Returns the base URL this client was created with.
    pub fn base_url(&self) -> &str {
        self.base_url.as_str()
    }

    /// Concatenates the base URL and `endpoint` exactly as given.
    pub fn url(&self, endpoint: &str) -> String {
        let mut full = String::with_capacity(self.base_url.len() + endpoint.len());
        full.push_str(&self.base_url);
        full.push_str(endpoint);
        full
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Values produced by `MockServerConfig::default()`.
    const DEFAULT_KEY: &str = "test.12345678901234567890";
    const DEFAULT_URL: &str = "http://127.0.0.1:9876";

    #[test]
    fn test_mock_server_config_default() {
        let MockServerConfig { api_key, base_url } = MockServerConfig::default();
        assert_eq!(api_key, DEFAULT_KEY);
        assert_eq!(base_url, DEFAULT_URL);
    }

    #[test]
    fn test_mock_server_state_register_response() {
        let state = MockServerState::new(MockServerConfig::default());
        let expected = json!({"test": "data"});
        state.register_response("/test", expected.clone());
        assert_eq!(state.get_response("/test"), Some(expected));
    }

    #[test]
    fn test_mock_server_client_url() {
        let client = MockServerClient::new(DEFAULT_URL.to_string());
        assert_eq!(client.base_url(), DEFAULT_URL);
        assert_eq!(client.url("/api/test"), "http://127.0.0.1:9876/api/test");
    }
}