use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::response::sse::{Event, Sse};
use axum::response::{IntoResponse, Response};
use axum::Json;
use futures::stream::{self, Stream};
use tokio::sync::Mutex;
use crate::model::{InferenceContext, ModelConfig};
use crate::sampling::{Sampler, SamplerConfig};
use crate::tokenizer::Tokenizer;
use crate::Model;
use super::types::*;
pub struct AppState {
pub model: Arc<dyn Model>,
pub tokenizer: Arc<Tokenizer>,
pub config: ModelConfig,
pub model_name: String,
pub inference_lock: Mutex<()>,
}
pub async fn health(State(state): State<Arc<AppState>>) -> Json<HealthResponse> {
Json(HealthResponse {
status: "ok".to_string(),
model: state.model_name.clone(),
context_size: state.config.max_seq_len,
})
}
pub async fn list_models(State(state): State<Arc<AppState>>) -> Json<ModelsResponse> {
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
Json(ModelsResponse {
object: "list".to_string(),
data: vec![ModelInfo {
id: state.model_name.clone(),
object: "model".to_string(),
created,
owned_by: "llama-rs".to_string(),
}],
})
}
pub async fn chat_completions(
State(state): State<Arc<AppState>>,
Json(request): Json<ChatCompletionRequest>,
) -> Response {
let _lock = state.inference_lock.lock().await;
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
let request_id = format!("chatcmpl-{}", created);
let prompt = format_chat_messages(&request.messages);
let sampler_config = SamplerConfig {
temperature: request.temperature,
top_p: request.top_p,
frequency_penalty: request.frequency_penalty,
presence_penalty: request.presence_penalty,
..Default::default()
};
match generate_response(
&state,
&prompt,
request.max_tokens,
sampler_config,
request.stop.as_deref(),
)
.await
{
Ok((response_text, prompt_tokens, completion_tokens)) => {
if request.stream {
let stream = create_chat_stream(
request_id,
state.model_name.clone(),
created,
response_text,
);
Sse::new(stream).into_response()
} else {
let response = ChatCompletionResponse {
id: request_id,
object: "chat.completion".to_string(),
created,
model: state.model_name.clone(),
choices: vec![ChatCompletionChoice {
index: 0,
message: ChatMessage {
role: Role::Assistant,
content: response_text,
},
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
},
};
Json(response).into_response()
}
}
Err(e) => {
let error = ErrorResponse::new(e.to_string(), "server_error");
(StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response()
}
}
}
pub async fn completions(
State(state): State<Arc<AppState>>,
Json(request): Json<CompletionRequest>,
) -> Response {
let _lock = state.inference_lock.lock().await;
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
let request_id = format!("cmpl-{}", created);
let sampler_config = SamplerConfig {
temperature: request.temperature,
top_p: request.top_p,
..Default::default()
};
match generate_response(
&state,
&request.prompt,
request.max_tokens,
sampler_config,
request.stop.as_deref(),
)
.await
{
Ok((response_text, prompt_tokens, completion_tokens)) => {
let response = CompletionResponse {
id: request_id,
object: "text_completion".to_string(),
created,
model: state.model_name.clone(),
choices: vec![CompletionChoice {
text: response_text,
index: 0,
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
},
};
Json(response).into_response()
}
Err(e) => {
let error = ErrorResponse::new(e.to_string(), "server_error");
(StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response()
}
}
}
fn format_chat_messages(messages: &[ChatMessage]) -> String {
let mut prompt = String::new();
for (i, msg) in messages.iter().enumerate() {
match msg.role {
Role::System => {
prompt.push_str(&format!(
"[INST] <<SYS>>\n{}\n<</SYS>>\n\n",
msg.content
));
}
Role::User => {
if i == 0 || matches!(messages.get(i - 1).map(|m| &m.role), Some(Role::System)) {
prompt.push_str(&format!("{} [/INST]", msg.content));
} else {
prompt.push_str(&format!(" [INST] {} [/INST]", msg.content));
}
}
Role::Assistant => {
prompt.push_str(&format!(" {}", msg.content));
}
}
}
prompt
}
async fn generate_response(
state: &AppState,
prompt: &str,
max_tokens: usize,
sampler_config: SamplerConfig,
_stop_sequences: Option<&[String]>,
) -> Result<(String, usize, usize), Box<dyn std::error::Error + Send + Sync>> {
let backend: Arc<dyn crate::Backend> =
Arc::new(crate::backend::cpu::CpuBackend::new());
let mut ctx = InferenceContext::new(&state.config, backend);
let mut sampler = Sampler::new(sampler_config, state.config.vocab_size);
let prompt_tokens = state.tokenizer.encode(prompt, true)?;
let prompt_len = prompt_tokens.len();
let mut all_tokens = prompt_tokens.clone();
for (i, &token) in prompt_tokens.iter().enumerate() {
if i < state.config.max_seq_len {
let _ = state.model.forward(&[token], &mut ctx);
}
}
let mut response_text = String::new();
let mut completion_tokens = 0;
for _ in 0..max_tokens {
let last_token = *all_tokens.last().unwrap_or(&state.tokenizer.special_tokens.bos_token_id);
let logits = state.model.forward(&[last_token], &mut ctx)?;
let next_token = sampler.sample(&logits, &all_tokens);
if next_token == state.tokenizer.special_tokens.eos_token_id {
break;
}
if let Ok(text) = state.tokenizer.decode(&[next_token]) {
if text.contains("[INST]") || text.contains("</s>") {
break;
}
response_text.push_str(&text);
}
all_tokens.push(next_token);
completion_tokens += 1;
}
Ok((response_text, prompt_len, completion_tokens))
}
fn create_chat_stream(
request_id: String,
model: String,
created: u64,
response_text: String,
) -> impl Stream<Item = Result<Event, std::convert::Infallible>> {
let chunks = vec![
ChatCompletionChunk {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: ChatCompletionDelta {
role: Some(Role::Assistant),
content: None,
},
finish_reason: None,
}],
},
ChatCompletionChunk {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: ChatCompletionDelta {
role: None,
content: Some(response_text),
},
finish_reason: None,
}],
},
ChatCompletionChunk {
id: request_id,
object: "chat.completion.chunk".to_string(),
created,
model,
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: ChatCompletionDelta {
role: None,
content: None,
},
finish_reason: Some("stop".to_string()),
}],
},
];
stream::iter(chunks.into_iter().map(|chunk| {
let data = serde_json::to_string(&chunk).unwrap_or_default();
Ok(Event::default().data(data))
}))
}
#[cfg(feature = "rag")]
pub struct RagState {
pub knowledge_bases: tokio::sync::RwLock<HashMap<String, crate::rag::KnowledgeBaseConfig>>,
pub rag_config: crate::rag::RagConfig,
}
#[cfg(feature = "rag")]
impl RagState {
pub fn new(rag_config: crate::rag::RagConfig) -> Self {
Self {
knowledge_bases: tokio::sync::RwLock::new(HashMap::new()),
rag_config,
}
}
}
#[cfg(feature = "rag")]
pub async fn retrieve(
State(rag_state): State<Arc<RagState>>,
Json(request): Json<RetrieveRequest>,
) -> Response {
use crate::rag::{KnowledgeBase, KnowledgeBaseConfig, RetrievalConfig, MetadataFilter};
let kb_config = {
let kbs = rag_state.knowledge_bases.read().await;
kbs.get(&request.knowledge_base_id).cloned().unwrap_or_else(|| {
KnowledgeBaseConfig {
name: request.knowledge_base_id.clone(),
storage: rag_state.rag_config.clone(),
..Default::default()
}
})
};
let kb = match KnowledgeBase::connect(kb_config).await {
Ok(kb) => kb,
Err(e) => {
let error = ErrorResponse::new(
format!("Failed to connect to knowledge base: {}", e),
"knowledge_base_error",
);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response();
}
};
let mut retrieval_config = RetrievalConfig::default();
if let Some(ref config) = request.retrieval_configuration {
if let Some(ref vs_config) = config.vector_search_configuration {
retrieval_config.max_results = vs_config.number_of_results;
if let Some(ref filter) = vs_config.filter {
retrieval_config.filter = convert_filter(filter);
}
}
}
match kb.retrieve(&request.query, Some(retrieval_config)).await {
Ok(response) => {
let results: Vec<RetrievalResult> = response.chunks.into_iter().map(|chunk| {
RetrievalResult {
content: RetrievalResultContent {
text: chunk.content,
},
location: RetrievalResultLocation {
location_type: "CUSTOM".to_string(),
s3_location: None,
custom_location: Some(CustomLocation {
uri: chunk.source.uri,
}),
},
score: chunk.score,
metadata: chunk.metadata,
}
}).collect();
Json(RetrieveResponse {
retrieval_results: results,
next_token: None,
}).into_response()
}
Err(e) => {
let error = ErrorResponse::new(
format!("Retrieval failed: {}", e),
"retrieval_error",
);
(StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response()
}
}
}
#[cfg(feature = "rag")]
pub async fn retrieve_and_generate(
State((app_state, rag_state)): State<(Arc<AppState>, Arc<RagState>)>,
Json(request): Json<RetrieveAndGenerateRequest>,
) -> Response {
use crate::rag::{KnowledgeBase, KnowledgeBaseConfig, RetrievalConfig};
let kb_id = &request.retrieve_and_generate_configuration.knowledge_base_configuration.knowledge_base_id;
let kb_config = {
let kbs = rag_state.knowledge_bases.read().await;
kbs.get(kb_id).cloned().unwrap_or_else(|| {
KnowledgeBaseConfig {
name: kb_id.clone(),
storage: rag_state.rag_config.clone(),
..Default::default()
}
})
};
let kb = match KnowledgeBase::connect(kb_config).await {
Ok(kb) => kb,
Err(e) => {
let error = ErrorResponse::new(
format!("Failed to connect to knowledge base: {}", e),
"knowledge_base_error",
);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response();
}
};
let mut retrieval_config = RetrievalConfig::default();
if let Some(ref config) = request.retrieve_and_generate_configuration.knowledge_base_configuration.retrieval_configuration {
if let Some(ref vs_config) = config.vector_search_configuration {
retrieval_config.max_results = vs_config.number_of_results;
}
}
if let Some(ref gen_config) = request.retrieve_and_generate_configuration.knowledge_base_configuration.generation_configuration {
if let Some(ref template) = gen_config.prompt_template {
let converted = template.text_prompt_template
.replace("$query$", "{query}")
.replace("$search_results$", "{context}");
retrieval_config.prompt_template = Some(converted);
}
}
let rag_response = match kb.retrieve_and_generate(&request.input.text, Some(retrieval_config)).await {
Ok(resp) => resp,
Err(e) => {
let error = ErrorResponse::new(
format!("RAG failed: {}", e),
"rag_error",
);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response();
}
};
let (temperature, top_p, max_tokens) = if let Some(ref gen_config) =
request.retrieve_and_generate_configuration.knowledge_base_configuration.generation_configuration
{
if let Some(ref inf_config) = gen_config.inference_config {
if let Some(ref text_config) = inf_config.text_inference_config {
(text_config.temperature, text_config.top_p, text_config.max_tokens)
} else {
(0.7, 0.9, 256)
}
} else {
(0.7, 0.9, 256)
}
} else {
(0.7, 0.9, 256)
};
let _lock = app_state.inference_lock.lock().await;
let sampler_config = SamplerConfig {
temperature,
top_p,
..Default::default()
};
let generated_text = match generate_response(
&app_state,
&rag_response.output,
max_tokens,
sampler_config,
None,
).await {
Ok((text, _, _)) => text,
Err(e) => {
let error = ErrorResponse::new(
format!("Generation failed: {}", e),
"generation_error",
);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response();
}
};
let citations: Vec<Citation> = rag_response.citations.into_iter().map(|c| {
Citation {
generated_response_part: None,
retrieved_references: vec![RetrievedReference {
content: RetrievalResultContent {
text: c.content,
},
location: RetrievalResultLocation {
location_type: "CUSTOM".to_string(),
s3_location: None,
custom_location: Some(CustomLocation {
uri: c.source.uri,
}),
},
metadata: None,
}],
}
}).collect();
Json(RetrieveAndGenerateResponse {
output: RetrieveAndGenerateOutput {
text: generated_text,
},
citations,
session_id: request.session_id,
}).into_response()
}
#[cfg(feature = "rag")]
pub async fn ingest(
State(rag_state): State<Arc<RagState>>,
Json(request): Json<IngestRequest>,
) -> Response {
use crate::rag::{KnowledgeBase, KnowledgeBaseConfig, DataSource};
let kb_config = {
let kbs = rag_state.knowledge_bases.read().await;
kbs.get(&request.knowledge_base_id).cloned().unwrap_or_else(|| {
KnowledgeBaseConfig {
name: request.knowledge_base_id.clone(),
storage: rag_state.rag_config.clone(),
..Default::default()
}
})
};
let kb = match KnowledgeBase::connect(kb_config).await {
Ok(kb) => kb,
Err(e) => {
let error = ErrorResponse::new(
format!("Failed to connect to knowledge base: {}", e),
"knowledge_base_error",
);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response();
}
};
let mut total_docs = 0;
let mut total_chunks = 0;
let mut failures = Vec::new();
for doc in request.documents {
let source = DataSource::Text {
content: doc.content.text,
source_id: doc.document_id.clone(),
metadata: doc.metadata,
};
match kb.ingest(source).await {
Ok(result) => {
total_docs += result.documents_processed;
total_chunks += result.chunks_created;
for (id, err) in result.failures {
failures.push(IngestFailure {
document_id: id,
error_message: err,
});
}
}
Err(e) => {
failures.push(IngestFailure {
document_id: doc.document_id,
error_message: e.to_string(),
});
}
}
}
Json(IngestResponse {
documents_ingested: total_docs,
chunks_created: total_chunks,
failures,
}).into_response()
}
#[cfg(feature = "rag")]
pub async fn list_knowledge_bases(
State(rag_state): State<Arc<RagState>>,
Json(_request): Json<ListKnowledgeBasesRequest>,
) -> Response {
let kbs = rag_state.knowledge_bases.read().await;
let summaries: Vec<KnowledgeBaseSummary> = kbs.iter().map(|(id, config)| {
KnowledgeBaseSummary {
knowledge_base_id: id.clone(),
name: config.name.clone(),
description: config.description.clone(),
status: "ACTIVE".to_string(),
updated_at: current_timestamp(),
}
}).collect();
Json(ListKnowledgeBasesResponse {
knowledge_base_summaries: summaries,
next_token: None,
}).into_response()
}
#[cfg(feature = "rag")]
pub async fn get_knowledge_base(
State(rag_state): State<Arc<RagState>>,
Path(kb_id): Path<String>,
) -> Response {
use crate::rag::{KnowledgeBase, KnowledgeBaseConfig};
let kb_config = {
let kbs = rag_state.knowledge_bases.read().await;
kbs.get(&kb_id).cloned().unwrap_or_else(|| {
KnowledgeBaseConfig {
name: kb_id.clone(),
storage: rag_state.rag_config.clone(),
..Default::default()
}
})
};
match KnowledgeBase::connect(kb_config.clone()).await {
Ok(kb) => {
match kb.stats().await {
Ok(stats) => {
Json(GetKnowledgeBaseResponse {
knowledge_base: KnowledgeBaseDetail {
knowledge_base_id: kb_id,
name: stats.name,
description: kb_config.description,
status: "ACTIVE".to_string(),
storage_configuration: StorageConfigurationResponse {
storage_type: "PGVECTOR".to_string(),
vector_dimension: stats.embedding_dimension,
},
updated_at: current_timestamp(),
},
}).into_response()
}
Err(e) => {
let error = ErrorResponse::new(
format!("Failed to get stats: {}", e),
"knowledge_base_error",
);
(StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response()
}
}
}
Err(e) => {
let error = ErrorResponse::new(
format!("Knowledge base not found: {}", e),
"not_found",
);
(StatusCode::NOT_FOUND, Json(error)).into_response()
}
}
}
#[cfg(feature = "rag")]
pub async fn delete_knowledge_base(
State(rag_state): State<Arc<RagState>>,
Path(kb_id): Path<String>,
) -> Response {
use crate::rag::{KnowledgeBase, KnowledgeBaseConfig};
let kb_config = {
let mut kbs = rag_state.knowledge_bases.write().await;
kbs.remove(&kb_id).unwrap_or_else(|| {
KnowledgeBaseConfig {
name: kb_id.clone(),
storage: rag_state.rag_config.clone(),
..Default::default()
}
})
};
match KnowledgeBase::connect(kb_config).await {
Ok(kb) => {
match kb.delete().await {
Ok(_) => {
Json(serde_json::json!({
"knowledgeBaseId": kb_id,
"status": "DELETING"
})).into_response()
}
Err(e) => {
let error = ErrorResponse::new(
format!("Failed to delete: {}", e),
"delete_error",
);
(StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response()
}
}
}
Err(e) => {
let error = ErrorResponse::new(
format!("Knowledge base not found: {}", e),
"not_found",
);
(StatusCode::NOT_FOUND, Json(error)).into_response()
}
}
}
#[cfg(feature = "rag")]
fn convert_filter(filter: &RetrievalFilter) -> Option<crate::rag::MetadataFilter> {
use crate::rag::MetadataFilter;
if let Some(ref and_filters) = filter.and_all {
let converted: Vec<_> = and_filters.iter()
.filter_map(|f| convert_filter(f))
.collect();
if !converted.is_empty() {
return Some(MetadataFilter::and(converted));
}
}
if let Some(ref or_filters) = filter.or_all {
let converted: Vec<_> = or_filters.iter()
.filter_map(|f| convert_filter(f))
.collect();
if !converted.is_empty() {
return Some(MetadataFilter::or(converted));
}
}
if let Some(ref cond) = filter.equals {
return Some(MetadataFilter::eq(&cond.key, cond.value.clone()));
}
if let Some(ref cond) = filter.not_equals {
return Some(MetadataFilter::ne(&cond.key, cond.value.clone()));
}
if let Some(ref cond) = filter.greater_than {
return Some(MetadataFilter::gt(&cond.key, cond.value.clone()));
}
if let Some(ref cond) = filter.less_than {
return Some(MetadataFilter::lt(&cond.key, cond.value.clone()));
}
if let Some(ref cond) = filter.string_contains {
if let Some(s) = cond.value.as_str() {
return Some(MetadataFilter::contains(&cond.key, s));
}
}
if let Some(ref cond) = filter.starts_with {
if let Some(s) = cond.value.as_str() {
return Some(MetadataFilter::starts_with(&cond.key, s));
}
}
None
}
#[cfg(feature = "rag")]
fn current_timestamp() -> String {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
format!("{}Z", now)
}