use std::sync::Arc;

use anyhow::{Error, Result};
use axum::{
    extract::State,
    http::StatusCode,
    response::{sse::{Event, Sse}, IntoResponse, Json},
    routing::{get, post},
    Router,
};
use bitmamba::{model, BitMambaStudent};
use candle_core::{DType, Device};
use candle_nn::VarBuilder;
use futures::stream::Stream;
use hf_hub::{api::sync::Api, Repo, RepoType};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tower_http::cors::CorsLayer;
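
/// Request body for `POST /v1/chat/completions` (OpenAI-compatible).
/// `max_tokens`, `temperature`, `stream`, and `stop` fall back to serde defaults when omitted.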
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ChatCompletionRequest {
model: String,
messages: Vec<ChatMessage>,
#[serde(default = "default_max_tokens")]
max_tokens: usize,
#[serde(default = "default_temperature")]
temperature: f64,
#[serde(default)]
stream: bool,
#[serde(default)]
stop: Option<Vec<String>>,
}
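
/// Request body for `POST /v1/completions`: a raw text prompt instead of chat messages.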
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct CompletionRequest {
model: String,
prompt: String,
#[serde(default = "default_max_tokens")]
max_tokens: usize,
#[serde(default = "default_temperature")]
temperature: f64,
#[serde(default)]
stream: bool,
#[serde(default)]
stop: Option<Vec<String>>,
}
fn default_max_tokens() -> usize { 256 }
fn default_temperature() -> f64 { 0.7 }
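
/// A single chat turn: a role ("system", "user", "assistant") and its text content.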
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ChatMessage {
role: String,
content: String,
}
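
/// Non-streaming response body for `/v1/chat/completions`, mirroring the OpenAI schema.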
#[derive(Debug, Serialize)]
struct ChatCompletionResponse {
id: String,
object: String,
created: i64,
model: String,
choices: Vec<ChatChoice>,
usage: Usage,
}
#[derive(Debug, Serialize)]
struct ChatChoice {
index: usize,
message: ChatMessage,
finish_reason: String,
}
#[derive(Debug, Serialize)]
struct CompletionResponse {
id: String,
object: String,
created: i64,
model: String,
choices: Vec<CompletionChoice>,
usage: Usage,
}
#[derive(Debug, Serialize)]
struct CompletionChoice {
index: usize,
text: String,
finish_reason: String,
}
#[derive(Debug, Serialize)]
struct Usage {
prompt_tokens: usize,
completion_tokens: usize,
total_tokens: usize,
}
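
/// One server-sent event in a streaming chat response (`object: "chat.completion.chunk"`).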
#[derive(Debug, Serialize)]
struct ChatCompletionChunk {
id: String,
object: String,
created: i64,
model: String,
choices: Vec<ChatChunkChoice>,
}
#[derive(Debug, Serialize)]
struct ChatChunkChoice {
index: usize,
delta: ChatDelta,
finish_reason: Option<String>,
}
#[derive(Debug, Serialize)]
struct ChatDelta {
#[serde(skip_serializing_if = "Option::is_none")]
role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
}
#[derive(Debug, Serialize)]
struct ModelList {
object: String,
data: Vec<ModelInfo>,
}
#[derive(Debug, Serialize)]
struct ModelInfo {
id: String,
object: String,
created: i64,
owned_by: String,
}
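
/// Shared server state: the model sits behind a tokio `Mutex`, so generation requests
/// are serialized; the tokenizer is immutable and shared across handlers.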
struct AppState {
model: Mutex<BitMambaStudent>,
tokenizer: Tokenizer,
}
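
/// Liveness probe.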
async fn health() -> &'static str {
"OK"
}
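
/// `GET /v1/models`: lists the single model served by this process.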
async fn list_models() -> Json<ModelList> {
Json(ModelList {
object: "list".to_string(),
data: vec![ModelInfo {
id: "bitmamba-student".to_string(),
object: "model".to_string(),
created: chrono::Utc::now().timestamp(),
owned_by: "rileyseaburg".to_string(),
}],
})
}
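
/// `POST /v1/chat/completions`: flattens the messages into a prompt, then either
/// streams tokens as SSE chunks or generates the full completion and returns it as JSON.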
async fn chat_completions(
State(state): State<Arc<AppState>>,
Json(req): Json<ChatCompletionRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
let prompt = format_chat_messages(&req.messages);
let tokens = state.tokenizer.encode(prompt.as_str(), true)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let input_ids = tokens.get_ids();
let prompt_tokens = input_ids.len();
if req.stream {
let stream = generate_stream(
state.clone(),
input_ids.to_vec(),
req.max_tokens,
req.temperature,
req.model.clone(),
);
Ok(Sse::new(stream).into_response())
} else {
let model = state.model.lock().await;
let output_ids = model.generate(input_ids, req.max_tokens, req.temperature)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
drop(model);
let new_tokens = &output_ids[prompt_tokens..];
let completion = state.tokenizer.decode(new_tokens, true)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp(),
model: req.model,
choices: vec![ChatChoice {
index: 0,
message: ChatMessage {
role: "assistant".to_string(),
content: completion,
},
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens,
completion_tokens: new_tokens.len(),
total_tokens: output_ids.len(),
},
};
Ok(Json(response).into_response())
}
}
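
/// `POST /v1/completions`: completes a raw prompt. Streaming is not implemented for
/// this endpoint, so the response is always a single JSON body.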
async fn completions(
State(state): State<Arc<AppState>>,
Json(req): Json<CompletionRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
let tokens = state.tokenizer.encode(req.prompt.as_str(), true)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let input_ids = tokens.get_ids();
let prompt_tokens = input_ids.len();
let model = state.model.lock().await;
let output_ids = model.generate(input_ids, req.max_tokens, req.temperature)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
drop(model);
let new_tokens = &output_ids[prompt_tokens..];
let completion = state.tokenizer.decode(new_tokens, true)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let response = CompletionResponse {
id: format!("cmpl-{}", uuid::Uuid::new_v4()),
object: "text_completion".to_string(),
created: chrono::Utc::now().timestamp(),
model: req.model,
choices: vec![CompletionChoice {
index: 0,
text: completion,
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens,
completion_tokens: new_tokens.len(),
total_tokens: output_ids.len(),
},
};
Ok(Json(response))
}
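
/// Flattens chat messages into a plain "Role: content" transcript and appends a
/// trailing "Assistant: " so the model continues from the assistant turn.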
fn format_chat_messages(messages: &[ChatMessage]) -> String {
let mut prompt = String::new();
for msg in messages {
match msg.role.as_str() {
"system" => prompt.push_str(&format!("System: {}\n\n", msg.content)),
"user" => prompt.push_str(&format!("User: {}\n\n", msg.content)),
"assistant" => prompt.push_str(&format!("Assistant: {}\n\n", msg.content)),
_ => prompt.push_str(&format!("{}: {}\n\n", msg.role, msg.content)),
}
}
prompt.push_str("Assistant: ");
prompt
}
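
/// Generates tokens one at a time and yields them as OpenAI-style SSE chunks, ending
/// with a `finish_reason: "stop"` chunk followed by the `[DONE]` sentinel. The model
/// lock is held for the lifetime of the stream.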
fn generate_stream(
state: Arc<AppState>,
input_ids: Vec<u32>,
max_tokens: usize,
temperature: f64,
model_name: String,
) -> impl Stream<Item = Result<Event, std::convert::Infallible>> {
async_stream::stream! {
let id = format!("chatcmpl-{}", uuid::Uuid::new_v4());
let created = chrono::Utc::now().timestamp();
let initial_chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model_name.clone(),
choices: vec![ChatChunkChoice {
index: 0,
delta: ChatDelta {
role: Some("assistant".to_string()),
content: None,
},
finish_reason: None,
}],
};
yield Ok(Event::default().data(serde_json::to_string(&initial_chunk).unwrap()));
let model = state.model.lock().await;
let mut token_ids = input_ids.clone();
let device = model.device().clone();
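        // Each step feeds the entire accumulated sequence back through the model;
        // no recurrent state is carried between iterations.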
for _ in 0..max_tokens {
let input_tensor = match candle_core::Tensor::new(&token_ids[..], &device) {
Ok(t) => match t.unsqueeze(0) {
Ok(t) => t,
Err(_) => break,
},
Err(_) => break,
};
let logits = match model.forward(&input_tensor) {
Ok(l) => l,
Err(_) => break,
};
let next_token_id = match model::sample(&logits, temperature) {
Ok(t) => t,
Err(_) => break,
};
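            // These ids appear to be Qwen-style end-of-text / end-of-turn markers;
            // hitting either one terminates generation.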
if next_token_id == 151643 || next_token_id == 151645 {
break;
}
token_ids.push(next_token_id);
let token_text = match state.tokenizer.decode(&[next_token_id], true) {
Ok(t) => t,
Err(_) => continue,
};
let chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model_name.clone(),
choices: vec![ChatChunkChoice {
index: 0,
delta: ChatDelta {
role: None,
content: Some(token_text),
},
finish_reason: None,
}],
};
yield Ok(Event::default().data(serde_json::to_string(&chunk).unwrap()));
}
let final_chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model_name,
choices: vec![ChatChunkChoice {
index: 0,
delta: ChatDelta {
role: None,
content: None,
},
finish_reason: Some("stop".to_string()),
}],
};
yield Ok(Event::default().data(serde_json::to_string(&final_chunk).unwrap()));
yield Ok(Event::default().data("[DONE]".to_string()));
}
}
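
/// Downloads the checkpoint and tokenizer from the Hugging Face Hub, loads the model
/// on CPU, and serves the OpenAI-compatible API on port 8000.
///
/// Example request once the server is up (sketch):
/// ```text
/// curl http://localhost:8000/v1/chat/completions \
///   -H "Content-Type: application/json" \
///   -d '{"model": "bitmamba-student", "messages": [{"role": "user", "content": "Hello"}]}'
/// ```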
#[tokio::main]
async fn main() -> Result<()> {
println!("=== BitMamba OpenAI-Compatible Server ===\n");
let device = Device::Cpu;
let api = Api::new()?;
let repo = api.repo(Repo::new("rileyseaburg/bitmamba-student".to_string(), RepoType::Model));
println!("Downloading model...");
let model_path = repo.get("model.safetensors")?;
let tokenizer_path = repo.get("tokenizer.json")?;
println!("Loading tokenizer...");
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(Error::msg)?;
println!("Loading model...");
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)? };
let model = BitMambaStudent::load(vb, device)?;
let state = Arc::new(AppState {
model: Mutex::new(model),
tokenizer,
});
let app = Router::new()
.route("/health", get(health))
.route("/v1/models", get(list_models))
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/completions", post(completions))
.layer(CorsLayer::permissive())
.with_state(state);
let addr = "0.0.0.0:8000";
println!("\nServer running at http://{}", addr);
println!("\nConfigure your client:");
println!(" Base URL: http://localhost:8000/v1");
println!(" Model: bitmamba-student");
println!("\nEndpoints:");
println!(" POST /v1/chat/completions");
println!(" POST /v1/completions");
println!(" GET /v1/models");
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
Ok(())
}