use crate::error::Result;
use std::collections::HashMap;
use std::time::Duration;
pub mod adapters;
mod allocator;
pub mod backend;
mod batch_scheduling;
mod diagnostics;
pub mod executor;
mod metrics;
pub mod mock_backend;
pub mod planner;
mod resilience;
pub mod scheduler;
mod simd_ops;
mod streaming_kv;
#[cfg(feature = "cuda")]
pub use scheduler::CudaScheduler;
pub use scheduler::{
AttentionBuffers, BlockWeights, GpuGenerateConfig, GpuModel, GpuModelConfig, WeightType,
};
pub use planner::{
plan_lm_head_path, plan_sampling, BatchPlanner, BlockForwardPlan, GenerationConfig,
GenerationStep, LmHeadPath, SamplingStrategy,
};
pub use allocator::{
blocked_matmul, naive_matmul, prefetch_read, sequential_sum, sum_with_prefetch,
CacheAlignedBuffer, ForwardArena, ScratchBuffer, TensorPool,
};
pub use diagnostics::{
DebugMode, DiagnosticsCollector, DiagnosticsSummary, LogConfig, LogEntry, LogLevel, Logger,
MemoryReport, MemoryTracker, PhaseTimer, RequestCapture, StateDump,
};
pub use resilience::{
BulkheadConfig, BulkheadManager, BulkheadPermit, BulkheadStats, CircuitBreaker, CircuitConfig,
CircuitState, ErrorCategory, RequestType, RetryConfig, RetryDecision, RetryPolicy,
};
pub use simd_ops::{scalar_rope, scalar_softmax, simd_rope, simd_softmax};
pub use streaming_kv::{StreamingKVCache, StreamingKVCacheFp16};
pub use metrics::{
AsyncGpuResult, ComputeBackend, GpuBufferPool, GpuCompute, GpuPoolStats, HealthChecker,
HybridScheduler, InferenceMetrics, ShutdownCoordinator,
};
pub(crate) use metrics::{cpu_matmul, cpu_matmul_transpose_b, cpu_matmul_transposed_simd};
pub(crate) use scheduler::layer_norm_static;
pub use batch_scheduling::{
AllocationId, AsyncRequestQueue, BatchId, InferenceBatchScheduler, InferenceCompletionHandler,
InferenceEventNotifier, Priority, PriorityRequest, PriorityRequestQueue, RequestId,
ResourceTracker, SpeculativeBuffer, TimeoutManager, TokenBatch, TokenRateLimiter,
};
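/// Maximum size of a single GPU buffer allocation: 256 MiB.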
const MAX_GPU_BUFFER_BYTES: usize = 256 * 1024 * 1024;
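/// Vocabulary size above which a vocabulary is treated as large (65 536 entries).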
pub const LARGE_VOCAB_THRESHOLD: usize = 65536;
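/// Returns `true` if `elements` f32 values would not fit within `MAX_GPU_BUFFER_BYTES`.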
#[inline]
#[must_use]
pub fn exceeds_gpu_buffer_limit(elements: usize) -> bool {
elements.saturating_mul(std::mem::size_of::<f32>()) > MAX_GPU_BUFFER_BYTES
}
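/// A pending matrix-multiply operation: two `f32` operand buffers followed by
/// three dimension parameters (conventionally `m`, `k`, `n`).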
pub type MatmulOp = (Vec<f32>, Vec<f32>, usize, usize, usize);
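/// Attention workspace holding the Q, K, V and output tensors in one contiguous
/// allocation, laid out as `[Q | K | V | O]` with each tensor occupying
/// `max_seq_len * num_heads * head_dim` elements.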
#[derive(Debug)]
pub struct ContiguousAttentionBuffer {
data: Vec<f32>,
max_seq_len: usize,
#[allow(dead_code)]
num_heads: usize,
#[allow(dead_code)]
head_dim: usize,
tensor_size: usize,
}
impl ContiguousAttentionBuffer {
#[must_use]
pub fn new(max_seq_len: usize, num_heads: usize, head_dim: usize) -> Self {
let tensor_size = max_seq_len * num_heads * head_dim;
let data = vec![0.0f32; tensor_size * 4];
Self {
data,
max_seq_len,
num_heads,
head_dim,
tensor_size,
}
}
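/// Always `true`: all four tensors share a single allocation by construction.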
#[must_use]
pub fn is_contiguous(&self) -> bool {
true
}
#[must_use]
pub fn get_views(&self) -> (&[f32], &[f32], &[f32], &[f32]) {
let q_start = 0;
let k_start = self.tensor_size;
let v_start = self.tensor_size * 2;
let o_start = self.tensor_size * 3;
(
&self.data[q_start..k_start],
&self.data[k_start..v_start],
&self.data[v_start..o_start],
&self.data[o_start..],
)
}
pub fn get_views_mut(&mut self) -> (&mut [f32], &mut [f32], &mut [f32], &mut [f32]) {
let tensor_size = self.tensor_size;
let (q, rest) = self.data.split_at_mut(tensor_size);
let (k, rest) = rest.split_at_mut(tensor_size);
let (v, o) = rest.split_at_mut(tensor_size);
(q, k, v, o)
}
pub fn reset(&mut self) {
self.data.fill(0.0);
}
#[must_use]
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}
}
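/// Gathers embedding rows for `tokens` from a row-major `embedding_table` of
/// `vocab_size * hidden_dim` values. Out-of-range tokens produce a zero row
/// and a warning rather than a panic.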
#[must_use]
pub fn batch_embed(embedding_table: &[f32], tokens: &[usize], hidden_dim: usize) -> Vec<f32> {
if tokens.is_empty() || embedding_table.is_empty() {
return Vec::new();
}
let mut result = Vec::with_capacity(tokens.len() * hidden_dim);
for &token in tokens {
let start_idx = token * hidden_dim;
let end_idx = start_idx + hidden_dim;
if end_idx <= embedding_table.len() {
result.extend_from_slice(&embedding_table[start_idx..end_idx]);
} else {
eprintln!(
"Warning: batch_embed token {} OOB (offset={start_idx}, table_len={}). N-09 escape.",
token, embedding_table.len()
);
result.extend(std::iter::repeat_n(0.0, hidden_dim));
}
}
result
}
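/// Single-threaded feed-forward block: up-projection to `intermediate_dim`,
/// GELU (tanh approximation), then down-projection back to `hidden_dim`.
/// `w_up` is row-major `hidden_dim x intermediate_dim`; `w_down` is row-major
/// `intermediate_dim x hidden_dim`.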
#[must_use]
pub fn sequential_ffn(
input: &[f32],
w_up: &[f32],
w_down: &[f32],
hidden_dim: usize,
intermediate_dim: usize,
) -> Vec<f32> {
if input.is_empty() {
return Vec::new();
}
let mut intermediate = vec![0.0f32; intermediate_dim];
for i in 0..intermediate_dim {
let mut sum = 0.0f32;
for j in 0..hidden_dim {
sum += input[j] * w_up[j * intermediate_dim + i];
}
intermediate[i] =
sum * 0.5 * (1.0 + (sum * 0.797_884_5 * (1.0 + 0.044_715 * sum * sum)).tanh());
}
let mut output = vec![0.0f32; hidden_dim];
for i in 0..hidden_dim {
let mut sum = 0.0f32;
for j in 0..intermediate_dim {
sum += intermediate[j] * w_down[j * hidden_dim + i];
}
output[i] = sum;
}
output
}
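/// Rayon-parallel variant of [`sequential_ffn`] with the same weight layout and
/// activation; both projections are parallelised over their output elements.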
#[must_use]
pub fn parallel_ffn(
input: &[f32],
w_up: &[f32],
w_down: &[f32],
hidden_dim: usize,
intermediate_dim: usize,
) -> Vec<f32> {
use rayon::prelude::*;
if input.is_empty() {
return Vec::new();
}
let intermediate: Vec<f32> = (0..intermediate_dim)
.into_par_iter()
.map(|i| {
let sum: f32 = (0..hidden_dim)
.map(|j| input[j] * w_up[j * intermediate_dim + i])
.sum();
sum * 0.5 * (1.0 + (sum * 0.797_884_5 * (1.0 + 0.044_715 * sum * sum)).tanh())
})
.collect();
let output: Vec<f32> = (0..hidden_dim)
.into_par_iter()
.map(|i| {
(0..intermediate_dim)
.map(|j| intermediate[j] * w_down[j * hidden_dim + i])
.sum()
})
.collect();
output
}
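/// Reference LayerNorm: computes mean and variance in two passes, then
/// normalises and applies `gamma`/`beta` (missing entries default to 1.0 and 0.0).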
#[must_use]
pub fn standard_layernorm(input: &[f32], gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
if input.is_empty() {
return Vec::new();
}
let n = input.len() as f32;
let mean: f32 = input.iter().sum::<f32>() / n;
let variance: f32 = input.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / n;
let std_dev = (variance + eps).sqrt();
input
.iter()
.enumerate()
.map(|(i, &x)| {
let normalized = (x - mean) / std_dev;
normalized * gamma.get(i).copied().unwrap_or(1.0) + beta.get(i).copied().unwrap_or(0.0)
})
.collect()
}
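/// Single-pass LayerNorm using Welford's online algorithm for mean and
/// variance, matching [`standard_layernorm`] up to floating-point rounding.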
#[must_use]
pub fn fused_layernorm(input: &[f32], gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
if input.is_empty() {
return Vec::new();
}
let n = input.len();
let mut mean = 0.0f32;
let mut m2 = 0.0f32;
for (i, &x) in input.iter().enumerate() {
let delta = x - mean;
mean += delta / (i + 1) as f32;
let delta2 = x - mean;
m2 += delta * delta2;
}
let variance = m2 / n as f32;
let std_dev = (variance + eps).sqrt();
let inv_std = 1.0 / std_dev;
input
.iter()
.enumerate()
.map(|(i, &x)| {
let normalized = (x - mean) * inv_std;
normalized * gamma.get(i).copied().unwrap_or(1.0) + beta.get(i).copied().unwrap_or(0.0)
})
.collect()
}
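/// Dot product of two Q4_0-style quantised blocks: 18 bytes each, a
/// little-endian `f16` scale followed by 16 bytes packing 32 4-bit values
/// stored with an offset of 8. Returns 0.0 if either block is too short.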
#[must_use]
pub fn quantized_dot_q4(block_a: &[u8], block_b: &[u8]) -> f32 {
if block_a.len() < 18 || block_b.len() < 18 {
return 0.0;
}
let scale_a = half::f16::from_le_bytes([block_a[0], block_a[1]]).to_f32();
let scale_b = half::f16::from_le_bytes([block_b[0], block_b[1]]).to_f32();
let mut acc = 0i32;
for i in 0..16 {
let byte_a = block_a[2 + i];
let byte_b = block_b[2 + i];
let a_lo = (byte_a & 0x0F) as i32 - 8;
let a_hi = ((byte_a >> 4) & 0x0F) as i32 - 8;
let b_lo = (byte_b & 0x0F) as i32 - 8;
let b_hi = ((byte_b >> 4) & 0x0F) as i32 - 8;
acc += a_lo * b_lo + a_hi * b_hi;
}
(acc as f32) * scale_a * scale_b
}
include!("mod_quantized_dot.rs");
include!("mod_max_error_recovery.rs");
include!("mod_connection_state.rs");
include!("resource_monitor.rs");