use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use std::time::SystemTime;
use tokio::sync::mpsc;
use crate::batch::{new_batch_store, BatchStore};
use crate::batch_spool::{BatchQueueSender, BatchStore as DiskBatchStore};
use crate::files_store::FilesStore;
use crate::metrics::Metrics;
use crate::queue::{BatchRequest, VocabBytes};
use crate::rate_limit::PerKeyRateLimiter;
use crate::responses_store::ResponseStore;
use crate::router::ModelPool;
use crate::threads::stream::RunEventSender;
use crate::threads::{RunQueueSender, ThreadStore};
use oxillama_runtime::sampling::SamplerConfig;
use oxillama_runtime::{LoadedLora, PrefixCacheConfig, PrefixKvCache};
/// Shared application state produced by `AppState::new` or
/// `AppState::with_batch_pipeline`.
///
/// Every field typed `Option<_>` starts as `None` in both constructors and is
/// filled in by the corresponding `with_*` builder method.
pub struct AppState {
/// Sender half of the channel that carries [`BatchRequest`] values.
pub queue: mpsc::Sender<BatchRequest>,
/// Identifier of the model, as supplied to the constructor.
pub model_id: String,
/// Construction time as whole seconds since the Unix epoch
/// (0 if the system clock reports a time before the epoch).
pub loaded_at: u64,
/// Sampler configuration supplied to the constructor.
/// NOTE(review): presumably applied when a request omits sampling options — confirm.
pub default_sampler: SamplerConfig,
/// Optional vocabulary byte table supplied to the constructor.
pub vocab_bytes: Option<VocabBytes>,
/// Hidden size supplied to the constructor.
/// NOTE(review): presumably the model's embedding width — confirm.
pub hidden_size: usize,
/// Metrics handle; each constructor creates a fresh `Metrics::new()`.
pub metrics: Arc<Metrics>,
/// Batch store created with `new_batch_store()` (from `crate::batch`).
pub batch_store: BatchStore,
/// Shared handle to the batch store from `crate::batch_spool`.
pub batch_disk_store: Arc<DiskBatchStore>,
/// Sender used to queue batch work items.
/// `AppState::new` installs a sender whose receiver has already been dropped.
pub batch_queue_tx: BatchQueueSender,
/// Model pool behind a mutex; constructors initialise it with `ModelPool::new(4, 0)`.
pub model_pool: Mutex<ModelPool>,
/// Shared prefix KV cache, created with `PrefixCacheConfig::default()`.
pub prefix_cache: Arc<Mutex<PrefixKvCache>>,
/// Loaded LoRA handles keyed by string; starts empty.
/// NOTE(review): key is presumably the adapter name — confirm.
pub loras: Arc<RwLock<HashMap<String, Arc<LoadedLora>>>>,
/// Thread store; set by `with_threads`.
pub threads_store: Option<Arc<ThreadStore>>,
/// Run queue sender; set by `with_threads`.
pub run_queue_tx: Option<RunQueueSender>,
/// Files store; set by `with_files`.
pub files_store: Option<Arc<FilesStore>>,
/// Run event sender; set by `with_run_event_sender`.
pub run_event_tx_broadcast: Option<RunEventSender>,
/// Response store; set by `with_responses_store`.
pub responses_store: Option<Arc<ResponseStore>>,
/// Per-key rate limiter; set by `with_per_key_rate_limiter`.
pub per_key_rate_limiter: Option<Arc<PerKeyRateLimiter>>,
}
impl AppState {
    /// Builds an `AppState` with a placeholder batch pipeline.
    ///
    /// The batch store is rooted at `<temp dir>/oxillama_batch_spool`; if
    /// `DiskBatchStore::new` returns an error for that path, the temp dir
    /// itself is tried instead.
    ///
    /// The batch queue sender comes from a channel whose receiver is dropped
    /// immediately, so nothing ever consumes queued work and sends through
    /// `batch_queue_tx` report a closed channel. Use
    /// [`AppState::with_batch_pipeline`] to wire in a real store and worker.
    ///
    /// All remaining fields are initialised exactly as in
    /// [`AppState::with_batch_pipeline`], to which this constructor delegates.
    ///
    /// # Panics
    ///
    /// Panics if the fallback `DiskBatchStore::new(temp dir)` also fails.
    pub fn new(
        queue: mpsc::Sender<BatchRequest>,
        model_id: String,
        default_sampler: SamplerConfig,
        vocab_bytes: Option<VocabBytes>,
        hidden_size: usize,
    ) -> Self {
        let spool_dir = std::env::temp_dir().join("oxillama_batch_spool");
        let batch_disk_store = Arc::new(DiskBatchStore::new(spool_dir).unwrap_or_else(|_| {
            DiskBatchStore::new(std::env::temp_dir()).expect("fallback spool dir")
        }));
        // The `_` pattern does not bind the receiver, so it is dropped right
        // here and the channel is closed from the start.
        let (batch_queue_tx, _) =
            tokio::sync::mpsc::channel::<crate::batch_spool::BatchWorkItem>(1);

        // Delegate so the default field values live in exactly one place.
        Self::with_batch_pipeline(
            queue,
            model_id,
            default_sampler,
            vocab_bytes,
            hidden_size,
            batch_disk_store,
            batch_queue_tx,
        )
    }

    /// Attaches a thread store together with its run queue sender.
    pub fn with_threads(mut self, store: Arc<ThreadStore>, tx: RunQueueSender) -> Self {
        self.threads_store = Some(store);
        self.run_queue_tx = Some(tx);
        self
    }

    /// Attaches a files store.
    pub fn with_files(mut self, store: Arc<FilesStore>) -> Self {
        self.files_store = Some(store);
        self
    }

    /// Attaches the run event sender.
    pub fn with_run_event_sender(mut self, tx: RunEventSender) -> Self {
        self.run_event_tx_broadcast = Some(tx);
        self
    }

    /// Attaches a response store.
    pub fn with_responses_store(mut self, store: Arc<ResponseStore>) -> Self {
        self.responses_store = Some(store);
        self
    }

    /// Attaches a per-key rate limiter.
    pub fn with_per_key_rate_limiter(mut self, limiter: Arc<PerKeyRateLimiter>) -> Self {
        self.per_key_rate_limiter = Some(limiter);
        self
    }

    /// Builds an `AppState` around a caller-supplied batch store and batch
    /// queue sender.
    ///
    /// `loaded_at` is set to the current time in whole seconds since the Unix
    /// epoch. Metrics, the batch store, the model pool, the prefix cache and
    /// the LoRA map are freshly created; every optional subsystem starts as
    /// `None` and can be attached with the `with_*` builder methods.
    pub fn with_batch_pipeline(
        queue: mpsc::Sender<BatchRequest>,
        model_id: String,
        default_sampler: SamplerConfig,
        vocab_bytes: Option<VocabBytes>,
        hidden_size: usize,
        batch_disk_store: Arc<DiskBatchStore>,
        batch_queue_tx: BatchQueueSender,
    ) -> Self {
        // A clock set before the epoch makes `duration_since` fail; record 0
        // rather than refusing to construct the state.
        let loaded_at = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);

        Self {
            queue,
            model_id,
            loaded_at,
            default_sampler,
            vocab_bytes,
            hidden_size,
            metrics: Arc::new(Metrics::new()),
            batch_store: new_batch_store(),
            batch_disk_store,
            batch_queue_tx,
            // NOTE(review): the meaning of (4, 0) is defined by `ModelPool::new` — confirm.
            model_pool: Mutex::new(ModelPool::new(4, 0)),
            prefix_cache: Arc::new(Mutex::new(PrefixKvCache::new(PrefixCacheConfig::default()))),
            loras: Arc::new(RwLock::new(HashMap::new())),
            threads_store: None,
            run_queue_tx: None,
            files_store: None,
            run_event_tx_broadcast: None,
            responses_store: None,
            per_key_rate_limiter: None,
        }
    }
}