#[cfg(feature = "web")]
use axum::{
extract::Json,
http::{header, HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Router,
};
#[cfg(feature = "web")]
use tower_http::cors::{Any, CorsLayer};
#[cfg(all(feature = "web", feature = "streaming"))]
use futures::Stream;
#[cfg(feature = "web")]
use std::sync::Arc;
use crate::{MullamaError, TokenId};
use serde::{Deserialize, Serialize};
#[cfg(feature = "async")]
use crate::async_support::AsyncModel;
#[cfg(feature = "streaming")]
use crate::streaming::{StreamConfig, TokenStream};
/// Shared state handed to every handler through axum's `State` extractor.
///
/// `Clone` is derived because handlers take it by value; `metrics` is an `Arc`, so all
/// clones share one set of counters.
/// NOTE(review): `AsyncModel` is only imported under the `async` feature, so enabling `web`
/// presumably requires `async` as well — confirm in Cargo.toml.
#[cfg(feature = "web")]
#[derive(Clone)]
pub struct AppState {
/// Loaded model used by the generation/tokenization handlers.
pub model: AsyncModel,
/// Config supplied to the builder.
/// NOTE(review): not read by any handler in this file.
pub default_config: crate::config::MullamaConfig,
/// Request counters reported by `/health`.
pub metrics: Arc<tokio::sync::RwLock<ApiMetrics>>,
}
/// Aggregate request counters, updated by the `/generate` handler and serialized
/// into the `/health` response.
#[derive(Debug, Default, Clone, serde::Serialize)]
pub struct ApiMetrics {
/// Number of `/generate` calls received, including invalid ones.
pub total_requests: u64,
/// Number of `/generate` calls that produced a completion.
pub successful_requests: u64,
/// Number of requests counted as failed.
pub failed_requests: u64,
/// Running mean response time, in milliseconds.
pub avg_response_time_ms: f64,
/// Sum of the requested `max_tokens` budgets of successful generations
/// (a requested budget, not a measured count of produced tokens).
pub total_tokens_generated: u64,
}
#[cfg(feature = "web")]
impl AppState {
    /// Starts building an [`AppState`]; see [`AppStateBuilder`] for the available options.
    pub fn builder() -> AppStateBuilder {
        AppStateBuilder::default()
    }
}
/// Builder for [`AppState`]; a model path is mandatory, the config defaults to
/// `MullamaConfig::default()`.
#[cfg(feature = "web")]
pub struct AppStateBuilder {
// Path passed to `AsyncModel::load`; `build` errors if it is still `None`.
model_path: Option<String>,
// Becomes `AppState::default_config`.
config: crate::config::MullamaConfig,
}
#[cfg(feature = "web")]
impl AppStateBuilder {
    /// Creates a builder with no model path and the default `MullamaConfig`.
    pub fn new() -> Self {
        let config = crate::config::MullamaConfig::default();
        Self {
            config,
            model_path: None,
        }
    }

    /// Sets the path of the model that [`AppStateBuilder::build`] will load.
    pub fn model_path(self, path: impl Into<String>) -> Self {
        Self {
            model_path: Some(path.into()),
            ..self
        }
    }

    /// Replaces the configuration stored as `AppState::default_config`.
    pub fn config(self, config: crate::config::MullamaConfig) -> Self {
        Self { config, ..self }
    }

    /// Loads the model and assembles the [`AppState`] with fresh, zeroed metrics.
    ///
    /// # Errors
    /// `MullamaError::ConfigError` when no model path was set, or whatever
    /// `AsyncModel::load` returns when loading fails.
    pub async fn build(self) -> Result<AppState, MullamaError> {
        let Some(model_path) = self.model_path else {
            return Err(MullamaError::ConfigError(
                "Model path is required".to_string(),
            ));
        };
        let loaded = AsyncModel::load(model_path).await?;
        let metrics = Arc::new(tokio::sync::RwLock::new(ApiMetrics::default()));
        Ok(AppState {
            model: loaded,
            default_config: self.config,
            metrics,
        })
    }
}
#[cfg(feature = "web")]
impl Default for AppStateBuilder {
    /// Same as [`AppStateBuilder::new`]: no model path, default config.
    fn default() -> Self {
        AppStateBuilder::new()
    }
}
/// JSON body accepted by `POST /generate` and, with the `streaming` feature, `POST /stream`.
///
/// Omitted fields fall back to the `default_*` functions in this file.
/// NOTE(review): `/generate` forwards only `prompt` and `max_tokens` to the model; `/stream`
/// additionally forwards `temperature`, `top_k` and `top_p`. `repeat_penalty`, `seed` and
/// `stream` are parsed but not read by any handler in this file.
#[derive(Debug, Deserialize)]
pub struct GenerateRequest {
/// Text to continue; the handlers reject an empty prompt.
pub prompt: String,
/// Token budget (default 100); `/generate` rejects values outside 1..=4096.
#[serde(default = "default_max_tokens")]
pub max_tokens: usize,
/// Sampling temperature (default 0.8).
#[serde(default = "default_temperature")]
pub temperature: f32,
/// Top-k sampling cutoff (default 40).
#[serde(default = "default_top_k")]
pub top_k: i32,
/// Nucleus (top-p) sampling threshold (default 0.95).
#[serde(default = "default_top_p")]
pub top_p: f32,
/// Repetition penalty (default 1.1).
#[serde(default = "default_repeat_penalty")]
pub repeat_penalty: f32,
/// RNG seed (default 0).
#[serde(default)]
pub seed: u32,
/// Streaming flag (default false); streaming is actually selected by the route.
#[serde(default)]
pub stream: bool,
}
/// JSON body accepted by `POST /tokenize`.
#[derive(Debug, Deserialize)]
pub struct TokenizeRequest {
/// Text to tokenize; the handler rejects an empty string.
pub text: String,
/// Forwarded as the `add_bos` argument of the model's `tokenize` (default false).
#[serde(default)]
pub add_bos: bool,
/// Forwarded as the `special` argument of the model's `tokenize` (default false).
/// NOTE(review): presumably controls parsing of special/control tokens — confirm.
#[serde(default)]
pub special: bool,
}
/// JSON body returned by `POST /generate`.
#[derive(Debug, Serialize)]
pub struct GenerateResponse {
/// Text returned by the model.
pub text: String,
/// Currently echoes the requested `max_tokens`, not a measured token count.
pub tokens_generated: usize,
/// Wall-clock handler time in milliseconds.
pub generation_time_ms: f64,
/// Omitted from the JSON when `None` (the `/generate` handler always sets `None`).
#[serde(skip_serializing_if = "Option::is_none")]
pub token_ids: Option<Vec<TokenId>>,
}
/// JSON body returned by `POST /tokenize`.
#[derive(Debug, Serialize)]
pub struct TokenizeResponse {
/// Token ids produced by the model's tokenizer.
pub tokens: Vec<TokenId>,
/// `tokens.len()`, included for convenience.
pub count: usize,
}
/// JSON body returned by `GET /health`.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
/// Always `"healthy"` when the handler runs.
pub status: String,
/// Dimensions of the loaded model.
pub model_info: ModelInfo,
/// Snapshot of the shared request counters.
pub metrics: ApiMetrics,
}
/// Model dimensions reported by `GET /model` and embedded in `/health`.
#[derive(Debug, Serialize)]
pub struct ModelInfo {
/// From the model's `vocab_size()`.
pub vocab_size: i32,
/// From `n_ctx_train()` — presumably the training context length rather than the
/// runtime context window; confirm.
pub context_size: i32,
/// From `n_embd()`.
pub embedding_dim: i32,
/// From `n_layer()`.
pub layers: i32,
}
/// JSON error body produced from an `AppError`.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
/// Human-readable message.
pub error: String,
/// Machine-readable code such as `"BAD_REQUEST"`.
pub code: String,
/// Numeric HTTP status, duplicated into the body.
pub status: u16,
}
/// Payload of one `token` server-sent event emitted by `POST /stream`.
#[derive(Debug, Serialize)]
pub struct StreamChunk {
/// Text for this token.
pub text: String,
/// Token id.
pub token_id: TokenId,
/// Position reported by the token stream.
pub position: usize,
/// `true` on the last chunk; the stream ends after sending it.
pub is_final: bool,
}
/// Fallback `max_tokens` for a [`GenerateRequest`] that omits it.
const DEFAULT_MAX_TOKENS: usize = 100;
/// Fallback sampling temperature.
const DEFAULT_TEMPERATURE: f32 = 0.8;
/// Fallback top-k cutoff.
const DEFAULT_TOP_K: i32 = 40;
/// Fallback nucleus-sampling threshold.
const DEFAULT_TOP_P: f32 = 0.95;
/// Fallback repetition penalty.
const DEFAULT_REPEAT_PENALTY: f32 = 1.1;

// serde's `#[serde(default = "...")]` needs a function path, so each constant is
// exposed through a tiny accessor.
fn default_max_tokens() -> usize {
    DEFAULT_MAX_TOKENS
}
fn default_temperature() -> f32 {
    DEFAULT_TEMPERATURE
}
fn default_top_k() -> i32 {
    DEFAULT_TOP_K
}
fn default_top_p() -> f32 {
    DEFAULT_TOP_P
}
fn default_repeat_penalty() -> f32 {
    DEFAULT_REPEAT_PENALTY
}
#[cfg(feature = "web")]
pub mod handlers {
    //! Axum request handlers for the Mullama HTTP API.
    use super::*;
    use axum::extract::State;
    use axum::response::Json;
    #[cfg(feature = "streaming")]
    use axum::response::Sse;
    use std::time::Instant;

    /// Largest `max_tokens` value accepted by the generation endpoints.
    const MAX_TOKENS_LIMIT: usize = 4096;

    /// Validates the fields shared by `/generate` and `/stream` requests.
    ///
    /// # Errors
    /// [`AppError::BadRequest`] when the prompt is empty or `max_tokens` is outside
    /// `1..=MAX_TOKENS_LIMIT`.
    fn validate_generate_request(request: &GenerateRequest) -> Result<(), AppError> {
        if request.prompt.is_empty() {
            return Err(AppError::BadRequest("Prompt cannot be empty".to_string()));
        }
        if !(1..=MAX_TOKENS_LIMIT).contains(&request.max_tokens) {
            return Err(AppError::BadRequest(format!(
                "max_tokens must be between 1 and {}",
                MAX_TOKENS_LIMIT
            )));
        }
        Ok(())
    }

    /// Reads the loaded model's dimensions; shared by [`health`] and [`model_info`].
    fn collect_model_info(state: &AppState) -> ModelInfo {
        let model = state.model.model();
        ModelInfo {
            vocab_size: model.vocab_size(),
            context_size: model.n_ctx_train(),
            embedding_dim: model.n_embd(),
            layers: model.n_layer(),
        }
    }

    /// `POST /generate` — produces a complete (non-streamed) completion.
    ///
    /// Metrics: `total_requests` is bumped on entry. Then either `failed_requests`
    /// (validation or model error) or `successful_requests` is bumped. On success the
    /// running mean latency and the token budget total are updated as well.
    ///
    /// NOTE(review): only `prompt` and `max_tokens` reach the model; the sampling fields of
    /// [`GenerateRequest`] are not forwarded by `generate_async`.
    ///
    /// # Errors
    /// [`AppError::BadRequest`] for invalid input, [`AppError::Internal`] when generation fails.
    pub async fn generate(
        State(state): State<AppState>,
        Json(request): Json<GenerateRequest>,
    ) -> Result<Json<GenerateResponse>, AppError> {
        let start_time = Instant::now();
        state.metrics.write().await.total_requests += 1;

        if let Err(err) = validate_generate_request(&request) {
            state.metrics.write().await.failed_requests += 1;
            return Err(err);
        }

        // Convert the model error immediately so only the `Send` AppError is held across
        // the metrics `.await` below.
        let outcome = state
            .model
            .generate_async(&request.prompt, request.max_tokens)
            .await
            .map_err(|e| AppError::Internal(format!("Generation failed: {}", e)));
        let text = match outcome {
            Ok(text) => text,
            Err(err) => {
                state.metrics.write().await.failed_requests += 1;
                return Err(err);
            }
        };
        // Keep sub-millisecond precision; `as_millis()` would truncate to whole milliseconds.
        let generation_time = start_time.elapsed().as_secs_f64() * 1000.0;

        {
            let mut metrics = state.metrics.write().await;
            metrics.successful_requests += 1;
            // NOTE(review): this is the requested budget, not a measured token count.
            metrics.total_tokens_generated += request.max_tokens as u64;
            // Incremental mean over successful requests only. Failed and still-in-flight
            // requests contribute no latency sample, so they must not be in the denominator.
            let samples = metrics.successful_requests as f64;
            metrics.avg_response_time_ms +=
                (generation_time - metrics.avg_response_time_ms) / samples;
        }

        Ok(Json(GenerateResponse {
            text,
            tokens_generated: request.max_tokens,
            generation_time_ms: generation_time,
            token_ids: None,
        }))
    }

    /// `POST /stream` — streams tokens as server-sent events.
    ///
    /// Each token is sent as a `token` event carrying a JSON [`StreamChunk`]. A stream error
    /// is sent as an `error` event whose data is `{"error": <message>}`, and then the stream
    /// ends. A keep-alive comment is sent every second.
    ///
    /// # Errors
    /// [`AppError::BadRequest`] for invalid input, [`AppError::Internal`] if the token stream
    /// cannot be created.
    #[cfg(feature = "streaming")]
    pub async fn generate_stream(
        State(state): State<AppState>,
        Json(request): Json<GenerateRequest>,
    ) -> Result<
        Sse<impl Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>>,
        AppError,
    > {
        // Same limits as `/generate`, so a stream cannot request an unbounded token budget.
        validate_generate_request(&request)?;
        let config = StreamConfig::default()
            .max_tokens(request.max_tokens)
            .temperature(request.temperature)
            .top_k(request.top_k)
            .top_p(request.top_p);
        let token_stream = TokenStream::new(state.model.clone(), request.prompt, config)
            .await
            .map_err(|e| AppError::Internal(format!("Failed to create stream: {}", e)))?;
        let sse_stream = async_stream::stream! {
            let mut stream = token_stream;
            use futures::StreamExt;
            while let Some(result) = stream.next().await {
                match result {
                    Ok(token_data) => {
                        let chunk = StreamChunk {
                            text: token_data.text,
                            token_id: token_data.token_id,
                            position: token_data.position,
                            is_final: token_data.is_final,
                        };
                        let data = serde_json::to_string(&chunk).unwrap_or_default();
                        let event = axum::response::sse::Event::default()
                            .event("token")
                            .data(data);
                        yield Ok(event);
                        if token_data.is_final {
                            break;
                        }
                    }
                    Err(e) => {
                        // Serialize through serde_json so quotes and backslashes in the
                        // message are escaped and the event data stays valid JSON.
                        let payload = serde_json::json!({ "error": e.to_string() }).to_string();
                        let error_event = axum::response::sse::Event::default()
                            .event("error")
                            .data(payload);
                        yield Ok(error_event);
                        break;
                    }
                }
            }
        };
        Ok(Sse::new(sse_stream).keep_alive(
            axum::response::sse::KeepAlive::new()
                .interval(std::time::Duration::from_secs(1))
                .text("keep-alive-text"),
        ))
    }

    /// `POST /tokenize` — tokenizes `text` with the loaded model.
    ///
    /// # Errors
    /// [`AppError::BadRequest`] for empty text, [`AppError::Internal`] if tokenization fails.
    pub async fn tokenize(
        State(state): State<AppState>,
        Json(request): Json<TokenizeRequest>,
    ) -> Result<Json<TokenizeResponse>, AppError> {
        if request.text.is_empty() {
            return Err(AppError::BadRequest("Text cannot be empty".to_string()));
        }
        let tokens = state
            .model
            .model()
            .tokenize(&request.text, request.add_bos, request.special)
            .map_err(|e| AppError::Internal(format!("Tokenization failed: {}", e)))?;
        Ok(Json(TokenizeResponse {
            count: tokens.len(),
            tokens,
        }))
    }

    /// `GET /health` — static `"healthy"` status, model dimensions and a metrics snapshot.
    pub async fn health(State(state): State<AppState>) -> Json<HealthResponse> {
        let model_info = collect_model_info(&state);
        let metrics = state.metrics.read().await.clone();
        Json(HealthResponse {
            status: "healthy".to_string(),
            model_info,
            metrics,
        })
    }

    /// `GET /model` — dimensions of the loaded model.
    pub async fn model_info(State(state): State<AppState>) -> Json<ModelInfo> {
        Json(collect_model_info(&state))
    }
}
#[cfg(feature = "web")]
pub mod middleware {
    //! Tower layers and `axum::middleware::from_fn` handlers used when building the router.
    use super::*;
    use axum::response::Response;
    use std::time::Duration;
    use tower::timeout::TimeoutLayer;

    /// How long a request may run before the timeout layer gives up.
    const REQUEST_TIMEOUT: Duration = Duration::from_secs(300);

    /// CORS layer allowing any origin, any method and any request header.
    ///
    /// This is fully permissive; tighten it before exposing the service publicly.
    pub fn cors() -> CorsLayer {
        let any_origin = CorsLayer::new().allow_origin(Any);
        any_origin.allow_methods(Any).allow_headers(Any)
    }

    /// [`TimeoutLayer`] with a 300-second limit.
    ///
    /// tower's timeout reports expiry as an error, so axum needs it wrapped in an error
    /// handler (e.g. `HandleErrorLayer`) before `Router::layer` will accept it.
    pub fn timeout() -> TimeoutLayer {
        TimeoutLayer::new(REQUEST_TIMEOUT)
    }

    /// Prints one stdout line per request: method, URI, response status and latency.
    ///
    /// Never fails; the `Result` return type is one of the shapes `from_fn` accepts.
    pub async fn logging(
        request: axum::extract::Request,
        next: Next,
    ) -> Result<Response, StatusCode> {
        let started_at = std::time::Instant::now();
        let (method, uri) = (request.method().clone(), request.uri().clone());
        let response = next.run(request).await;
        let elapsed = started_at.elapsed();
        println!("{} {} - {} - {:?}", method, uri, response.status(), elapsed);
        Ok(response)
    }

    /// Rate-limiting hook.
    ///
    /// Currently a pass-through: every request is forwarded unchanged and no limit is enforced.
    pub async fn rate_limit(
        request: axum::extract::Request,
        next: Next,
    ) -> Result<Response, StatusCode> {
        let response = next.run(request).await;
        Ok(response)
    }
}
/// Handler error type; each variant maps to an HTTP status and a JSON `ErrorResponse`
/// in its `IntoResponse` impl.
#[cfg(feature = "web")]
#[derive(Debug)]
pub enum AppError {
/// 400 — invalid client input.
BadRequest(String),
/// 500 — model or server failure; the message is returned to the client as-is.
Internal(String),
/// 404.
NotFound(String),
/// 401.
Unauthorized(String),
}
#[cfg(feature = "web")]
impl IntoResponse for AppError {
    /// Renders the error as a JSON [`ErrorResponse`] under the matching HTTP status.
    fn into_response(self) -> Response {
        let (status, code, error) = match self {
            Self::BadRequest(text) => (StatusCode::BAD_REQUEST, "BAD_REQUEST", text),
            Self::Internal(text) => (StatusCode::INTERNAL_SERVER_ERROR, "INTERNAL_ERROR", text),
            Self::NotFound(text) => (StatusCode::NOT_FOUND, "NOT_FOUND", text),
            Self::Unauthorized(text) => (StatusCode::UNAUTHORIZED, "UNAUTHORIZED", text),
        };
        let body = ErrorResponse {
            status: status.as_u16(),
            code: code.to_owned(),
            error,
        };
        (status, Json(body)).into_response()
    }
}
/// Builder for the axum `Router` that exposes the handlers in `handlers`.
///
/// Create with `RouterBuilder::new()` (or `Default`), supply a state, toggle middleware,
/// then call `build`.
#[cfg(feature = "web")]
pub struct RouterBuilder {
// State handed to the handlers; `build` returns an error while this is `None`.
state: Option<AppState>,
// When true, `build` attaches `middleware::cors()`.
cors_enabled: bool,
// Request-timeout toggle, set through `timeout(..)`.
timeout_enabled: bool,
// When true, `build` attaches the `middleware::logging` handler.
logging_enabled: bool,
// When true, `build` attaches the `middleware::rate_limit` handler.
rate_limiting_enabled: bool,
}
#[cfg(feature = "web")]
impl RouterBuilder {
    /// Creates a builder with CORS, request timeout and logging enabled and rate limiting
    /// disabled. A state must still be supplied via [`RouterBuilder::state`].
    pub fn new() -> Self {
        Self {
            state: None,
            cors_enabled: true,
            timeout_enabled: true,
            logging_enabled: true,
            rate_limiting_enabled: false,
        }
    }

    /// Sets the shared [`AppState`] handed to every handler (required).
    pub fn state(mut self, state: AppState) -> Self {
        self.state = Some(state);
        self
    }

    /// Enables or disables the permissive CORS layer from `middleware::cors`.
    pub fn cors(mut self, enabled: bool) -> Self {
        self.cors_enabled = enabled;
        self
    }

    /// Enables or disables the request timeout from `middleware::timeout`.
    ///
    /// A request that exceeds it is answered with `408 Request Timeout`.
    pub fn timeout(mut self, enabled: bool) -> Self {
        self.timeout_enabled = enabled;
        self
    }

    /// Enables or disables per-request logging via `middleware::logging`.
    pub fn logging(mut self, enabled: bool) -> Self {
        self.logging_enabled = enabled;
        self
    }

    /// Enables or disables the `middleware::rate_limit` hook.
    ///
    /// NOTE(review): that hook currently forwards every request unchanged.
    pub fn rate_limiting(mut self, enabled: bool) -> Self {
        self.rate_limiting_enabled = enabled;
        self
    }

    /// Assembles the router with all routes and the enabled middleware.
    ///
    /// Layer order, innermost first: timeout → CORS → state → logging → rate limiting.
    ///
    /// # Errors
    /// Returns `Err("State is required")` if no state was supplied.
    pub fn build(self) -> Result<Router, &'static str> {
        let state = self.state.ok_or("State is required")?;
        let mut router = Router::new()
            .route("/generate", axum::routing::post(handlers::generate))
            .route("/tokenize", axum::routing::post(handlers::tokenize))
            .route("/health", axum::routing::get(handlers::health))
            .route("/model", axum::routing::get(handlers::model_info));
        #[cfg(feature = "streaming")]
        {
            router = router.route("/stream", axum::routing::post(handlers::generate_stream));
        }
        // Previously `timeout_enabled` was recorded but never applied. tower's TimeoutLayer
        // fails with a `BoxError` on expiry, but axum routes must be infallible, so
        // `HandleErrorLayer` turns that error into a 408. It is added before CORS so that
        // CORS wraps it and timeout responses still carry CORS headers.
        if self.timeout_enabled {
            router = router.layer(
                tower::ServiceBuilder::new()
                    .layer(axum::error_handling::HandleErrorLayer::new(
                        |_: tower::BoxError| async { StatusCode::REQUEST_TIMEOUT },
                    ))
                    .layer(middleware::timeout()),
            );
        }
        if self.cors_enabled {
            router = router.layer(middleware::cors());
        }
        let mut router = router.with_state(state);
        if self.logging_enabled {
            router = router.layer(axum::middleware::from_fn(middleware::logging));
        }
        if self.rate_limiting_enabled {
            router = router.layer(axum::middleware::from_fn(middleware::rate_limit));
        }
        Ok(router)
    }
}
#[cfg(feature = "web")]
impl Default for RouterBuilder {
    /// Same as [`RouterBuilder::new`]: CORS, timeout and logging on, rate limiting off.
    fn default() -> Self {
        RouterBuilder::new()
    }
}
pub mod utils {
    //! Service bootstrap and `Authorization`-header helpers.
    use super::*;

    /// Loads the model at `model_path`, builds the default router (CORS, timeout and logging
    /// enabled) and serves it on `bind_address` until the server stops.
    ///
    /// # Errors
    /// Any error from loading the model, building the router, binding the listener or serving.
    #[cfg(feature = "web")]
    pub async fn create_service(
        model_path: impl Into<String>,
        bind_address: impl Into<String>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let state = AppState::builder().model_path(model_path).build().await?;
        let router = RouterBuilder::new()
            .state(state)
            .cors(true)
            .timeout(true)
            .logging(true)
            .build()?;
        let bind_addr = bind_address.into();
        let listener = tokio::net::TcpListener::bind(&bind_addr).await?;
        println!("Mullama web service running on http://{}", bind_addr);
        axum::serve(listener, router).await?;
        Ok(())
    }

    /// Extracts the credential from an `Authorization: Bearer <token>` header.
    ///
    /// The scheme name is matched case-insensitively (RFC 7235 §2.1) and whitespace around
    /// the token is trimmed. The result is `None` when the header is missing, is not visible
    /// ASCII, uses another scheme, or carries an empty token.
    pub fn extract_bearer_token(headers: &HeaderMap) -> Option<String> {
        let value = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
        let (scheme, token) = value.split_once(' ')?;
        if !scheme.eq_ignore_ascii_case("Bearer") {
            return None;
        }
        let token = token.trim();
        if token.is_empty() {
            None
        } else {
            Some(token.to_string())
        }
    }

    /// Returns `true` when `api_key` is non-empty and equal to `expected`.
    ///
    /// The byte comparison does not stop at the first mismatch, so response timing does not
    /// reveal how much of a guessed key was correct. The length is still compared up front,
    /// so the key's length is not hidden.
    pub fn validate_api_key_against(api_key: &str, expected: &str) -> bool {
        !api_key.is_empty() && constant_time_eq(api_key.as_bytes(), expected.as_bytes())
    }

    /// Accepts any non-empty key without checking it against anything.
    #[deprecated(note = "Use validate_api_key_against with an explicit expected key")]
    pub fn validate_api_key(api_key: &str) -> bool {
        !api_key.is_empty()
    }

    /// Best-effort constant-time equality for byte slices (no early exit on mismatch).
    ///
    /// The optimizer is not formally barred from adding a short-circuit; use a vetted crate
    /// (e.g. `subtle`) if a hard guarantee is required.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        diff == 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_request_defaults() {
        let json = r#"{"prompt": "Hello"}"#;
        let request: GenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.prompt, "Hello");
        assert_eq!(request.max_tokens, 100);
        assert_eq!(request.temperature, 0.8);
        assert_eq!(request.top_k, 40);
        assert_eq!(request.top_p, 0.95);
        assert_eq!(request.repeat_penalty, 1.1);
        assert_eq!(request.seed, 0);
        assert!(!request.stream);
    }

    #[test]
    fn test_tokenize_request_defaults() {
        let request: TokenizeRequest = serde_json::from_str(r#"{"text": "hi"}"#).unwrap();
        assert_eq!(request.text, "hi");
        assert!(!request.add_bos);
        assert!(!request.special);
    }

    #[test]
    fn test_error_response_serialization() {
        let error = ErrorResponse {
            error: "Test error".to_string(),
            code: "TEST_ERROR".to_string(),
            status: 400,
        };
        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("Test error"));
        assert!(json.contains("TEST_ERROR"));
        assert!(json.contains("400"));
    }

    #[test]
    fn test_bearer_token_extraction() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::AUTHORIZATION,
            "Bearer test-token-123".parse().unwrap(),
        );
        let token = utils::extract_bearer_token(&headers);
        assert_eq!(token, Some("test-token-123".to_string()));
    }

    #[test]
    fn test_bearer_token_missing_header() {
        assert_eq!(utils::extract_bearer_token(&HeaderMap::new()), None);
    }

    #[test]
    fn test_bearer_token_other_scheme() {
        let mut headers = HeaderMap::new();
        headers.insert(header::AUTHORIZATION, "Basic dXNlcjpwYXNz".parse().unwrap());
        assert_eq!(utils::extract_bearer_token(&headers), None);
    }

    #[test]
    fn test_validate_api_key_against() {
        assert!(utils::validate_api_key_against("secret", "secret"));
        assert!(!utils::validate_api_key_against("secreT", "secret"));
        assert!(!utils::validate_api_key_against("sec", "secret"));
        assert!(!utils::validate_api_key_against("", ""));
    }
}
// Fail fast with a clear message: nearly every item here is gated on `feature = "web"`,
// and `utils` uses `HeaderMap`/`header`, which are imported only under that feature.
// NOTE(review): this fires whenever this file is compiled without `web`; presumably the
// parent module declares it behind the same feature — confirm.
#[cfg(not(feature = "web"))]
compile_error!("Web support requires the 'web' feature to be enabled");