use std::sync::Arc;
use axum::{
extract::{Query, State},
http::StatusCode,
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use crate::{
apr::{AprModel, HEADER_SIZE, MAGIC},
audit::{AuditLogger, AuditRecord, InMemoryAuditSink},
cache::{CacheKey, ModelCache},
error::RealizarError,
explain::ShapExplanation,
layers::{Model, ModelConfig},
metrics::MetricsCollector,
registry::ModelRegistry,
tokenizer::BPETokenizer,
};
#[cfg(feature = "cuda")]
pub mod apr_q4k_scheduler;
#[cfg(feature = "cuda")]
pub mod cuda_batch_scheduler;
#[cfg(feature = "cuda")]
pub mod iteration_scheduler;
mod openai_handlers;
pub(crate) use openai_handlers::{
openai_chat_completions_handler, openai_chat_completions_stream_handler, openai_models_handler,
};
mod gpu_handlers;
pub(crate) use gpu_handlers::{
batch_generate_handler, batch_tokenize_handler, generate_handler,
gpu_batch_completions_handler, gpu_status_handler, gpu_warmup_handler, models_handler,
stream_generate_handler, tokenize_handler,
};
#[cfg(feature = "gpu")]
pub use gpu_handlers::{
BatchProcessResult, BatchQueueStats, ContinuousBatchRequest, ContinuousBatchResponse,
GpuBatchRequest, GpuBatchResponse, GpuBatchResult, GpuBatchStats, GpuStatusResponse,
GpuWarmupResponse,
};
#[cfg(feature = "gpu")]
pub use gpu_handlers::{spawn_batch_processor, BatchConfig};
mod realize_handlers;
pub(crate) use realize_handlers::{
clean_chat_output, format_chat_messages, openai_completions_handler, openai_embeddings_handler,
realize_embed_handler, realize_model_handler, realize_reload_handler,
};
#[cfg(feature = "cuda")]
pub(crate) use realize_handlers::{logprobs_handler, perplexity_handler};
pub use realize_handlers::{
CompletionChoice, CompletionRequest, CompletionResponse, ContextWindowConfig,
ContextWindowManager, EmbeddingData, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage,
ModelLineage, ModelMetadataResponse, ReloadRequest, ReloadResponse,
};
mod apr_handlers;
pub(crate) use apr_handlers::{apr_audit_handler, apr_explain_handler, apr_predict_handler};
mod types;
pub use crate::registry::ModelInfo;
pub use types::{default_max_tokens, default_top_k};
#[cfg(test)]
pub(crate) use types::{default_strategy, default_temperature, default_top_p};
pub use types::{
BatchGenerateRequest, BatchGenerateResponse, BatchTokenizeRequest, BatchTokenizeResponse,
ErrorResponse, GenerateRequest, GenerateResponse, HealthResponse, ModelsResponse,
StreamDoneEvent, StreamTokenEvent, TokenizeRequest, TokenizeResponse,
};
/// Shared application state handed to every HTTP handler (via axum's `State`
/// extractor).
///
/// Heavyweight members are wrapped in `Arc`, so cloning the whole state per
/// request is cheap (reference-count bumps only). Most fields are `Option`
/// because the server can be started with different subsets of backends
/// loaded; GPU- and CUDA-specific fields exist only when the corresponding
/// cargo feature is compiled in.
#[derive(Clone)]
pub struct AppState {
    /// Primary in-process model, if one was loaded at startup.
    model: Option<Arc<Model>>,
    /// BPE tokenizer paired with `model`, if available.
    tokenizer: Option<Arc<BPETokenizer>>,
    // Retained for future use; `#[allow(dead_code)]` shows nothing reads it yet.
    #[allow(dead_code)]
    cache: Option<Arc<ModelCache>>,
    // Key into `cache`; likewise currently unused.
    #[allow(dead_code)]
    cache_key: Option<CacheKey>,
    /// Always-present metrics collector (not optional — every request records).
    metrics: Arc<MetricsCollector>,
    /// Registry of known models; presumably backs the `/models` listing —
    /// confirm against `models_handler`.
    registry: Option<Arc<ModelRegistry>>,
    /// Model id used when a request does not name one explicitly.
    default_model_id: Option<String>,
    /// APR-format model served by the `apr_*` handlers, if loaded.
    apr_model: Option<Arc<AprModel>>,
    /// Logger that feeds audit records into `audit_sink` (see
    /// `create_audit_state`).
    audit_logger: Arc<AuditLogger>,
    /// In-memory sink read back by `apr_audit_handler` (assumption — verify).
    audit_sink: Arc<InMemoryAuditSink>,
    // GPU model needs interior mutability; RwLock suggests concurrent readers
    // with occasional exclusive access during inference/warmup.
    #[cfg(feature = "gpu")]
    gpu_model: Option<Arc<std::sync::RwLock<crate::gpu::GpuModel>>>,
    /// CPU-side quantized (GGUF) model; note this one is NOT feature-gated.
    quantized_model: Option<Arc<crate::gguf::OwnedQuantizedModel>>,
    #[cfg(feature = "gpu")]
    cached_model: Option<Arc<crate::gguf::OwnedQuantizedModelCachedSync>>,
    /// Metrics about CPU/GPU dispatch decisions (see `dispatch_metrics.rs`
    /// included below).
    #[cfg(feature = "gpu")]
    dispatch_metrics: Option<Arc<crate::gguf::DispatchMetrics>>,
    // Channel into the continuous-batching processor spawned by
    // `spawn_batch_processor`; handlers enqueue requests here.
    #[cfg(feature = "gpu")]
    batch_request_tx: Option<tokio::sync::mpsc::Sender<ContinuousBatchRequest>>,
    #[cfg(feature = "gpu")]
    batch_config: Option<BatchConfig>,
    // CUDA-resident quantized model, again behind a RwLock for shared access.
    #[cfg(feature = "cuda")]
    cuda_model: Option<Arc<std::sync::RwLock<crate::gguf::OwnedQuantizedModelCuda>>>,
    /// Channel into the CUDA batch scheduler (`cuda_batch_scheduler` module).
    #[cfg(feature = "cuda")]
    cuda_batch_tx: Option<tokio::sync::mpsc::Sender<cuda_batch_scheduler::CudaBatchRequest>>,
    /// Channel into the APR Q4K scheduler (`apr_q4k_scheduler` module).
    #[cfg(feature = "cuda")]
    apr_q4k_tx: Option<tokio::sync::mpsc::Sender<apr_q4k_scheduler::AprQ4kRequest>>,
    /// APR transformer backend (CPU path), if loaded.
    apr_transformer: Option<Arc<crate::apr_transformer::AprTransformer>>,
    // Mutex (not RwLock) here — presumably this model's API needs exclusive
    // access even for inference; confirm against SafeTensorsCudaModel.
    #[cfg(feature = "cuda")]
    safetensors_cuda_model:
        Option<Arc<std::sync::Mutex<crate::safetensors_cuda::SafeTensorsCudaModel>>>,
    /// Architecture string cached at load time so handlers avoid re-deriving it.
    cached_architecture: Option<String>,
    /// EOS token id cached at load time, for generation stop checks.
    cached_eos_token_id: Option<u32>,
    /// Verbose request logging toggle.
    verbose: bool,
    /// Extra-detailed tracing toggle (finer than `verbose` — assumption).
    trace: bool,
}
/// Builds the audit pair used by `AppState`: a logger and the in-memory sink
/// it writes into.
///
/// The sink is shared — one handle is boxed inside the logger (through
/// `InMemorySinkWrapper`, which adapts `Arc<InMemoryAuditSink>` to the
/// `AuditSink` trait object the logger requires) and the other is returned so
/// callers can read recorded entries back out.
fn create_audit_state() -> (Arc<AuditLogger>, Arc<InMemoryAuditSink>) {
    let sink = Arc::new(InMemoryAuditSink::new());
    // The logger takes ownership of a boxed trait object; clone the Arc so we
    // still hold a reader handle to hand back.
    let boxed_sink = Box::new(InMemorySinkWrapper(Arc::clone(&sink)));
    let logger = Arc::new(AuditLogger::new(boxed_sink).with_model_hash("demo-model-hash"));
    (logger, sink)
}
/// Newtype adapter: `AuditLogger::new` takes a `Box<dyn AuditSink>`, but we
/// also need to keep a shared reading handle to the sink. Wrapping the
/// `Arc<InMemoryAuditSink>` lets the same sink instance live both inside the
/// logger and in `AppState::audit_sink`.
struct InMemorySinkWrapper(Arc<InMemoryAuditSink>);
impl crate::audit::AuditSink for InMemorySinkWrapper {
    // Pure delegation — forward the batch to the shared sink.
    fn write_batch(&self, records: &[AuditRecord]) -> Result<(), crate::audit::AuditError> {
        self.0.write_batch(records)
    }
    // Pure delegation — flush the shared sink.
    fn flush(&self) -> Result<(), crate::audit::AuditError> {
        self.0.flush()
    }
}
// The remainder of this module is split across sibling files and spliced in
// textually with `include!` (NOT `mod` — the included items land directly in
// this module's namespace, with access to everything declared above). Judging
// by the filenames: AppState GPU helpers, a demo-state constructor, the axum
// router assembly, and dispatch-metrics endpoints — confirm against the files.
include!("mod_app_state_gpu.rs");
include!("mod_create_demo.rs");
include!("router.rs");
include!("dispatch_metrics.rs");