use std::fmt;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_stream::{wrappers::ReceiverStream, Stream};
use tonic::Status;
use crate::batching::BatchPriority;
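/// Transport-level configuration for the gRPC server.
///
/// Defaults are tuned for LLM workloads: a 16 MiB message limit, a long
/// (300 s) request timeout to accommodate slow generations, and HTTP/2
/// keepalives to detect dead connections.
///
/// # Examples
///
/// A minimal sketch, assuming this module's items are in scope:
///
/// ```ignore
/// let config = GrpcConfig::new()
///     .with_max_message_size(8 * 1024 * 1024)
///     .with_request_timeout(Duration::from_secs(120))
///     .with_reflection(false);
/// assert_eq!(config.max_message_size, 8 * 1024 * 1024);
/// ```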
#[derive(Debug, Clone)]
pub struct GrpcConfig {
pub max_message_size: usize,
pub connection_timeout: Duration,
pub request_timeout: Duration,
pub enable_reflection: bool,
pub enable_health: bool,
pub max_concurrent_streams: u32,
pub initial_connection_window_size: u32,
pub initial_stream_window_size: u32,
pub keepalive_interval: Duration,
pub keepalive_timeout: Duration,
}
impl Default for GrpcConfig {
fn default() -> Self {
Self {
max_message_size: 16 * 1024 * 1024,
connection_timeout: Duration::from_secs(30),
request_timeout: Duration::from_secs(300),
enable_reflection: true,
enable_health: true,
max_concurrent_streams: 200,
initial_connection_window_size: 1024 * 1024,
initial_stream_window_size: 512 * 1024,
keepalive_interval: Duration::from_secs(60),
keepalive_timeout: Duration::from_secs(20),
}
}
}
impl GrpcConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_max_message_size(mut self, size: usize) -> Self {
self.max_message_size = size;
self
}
pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
self.request_timeout = timeout;
self
}
pub fn with_reflection(mut self, enabled: bool) -> Self {
self.enable_reflection = enabled;
self
}
pub fn with_health(mut self, enabled: bool) -> Self {
self.enable_health = enabled;
self
}
pub fn with_max_concurrent_streams(mut self, max: u32) -> Self {
self.max_concurrent_streams = max;
self
}
}
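/// Request priority as carried on the wire (an `i32` in the protobuf enum).
///
/// Unknown values decode to `Unspecified`, which schedules like `Normal`.
///
/// # Examples
///
/// A sketch of decoding a raw wire value and mapping it onto the batching
/// queue (using `BatchPriority` from `crate::batching`):
///
/// ```ignore
/// let p = GrpcPriority::from(3);
/// assert_eq!(p, GrpcPriority::High);
/// assert_eq!(BatchPriority::from(p), BatchPriority::High);
/// ```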
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(i32)]
pub enum GrpcPriority {
Unspecified = 0,
Background = 1,
#[default]
Normal = 2,
High = 3,
Critical = 4,
}
impl From<GrpcPriority> for BatchPriority {
fn from(p: GrpcPriority) -> Self {
match p {
GrpcPriority::Unspecified | GrpcPriority::Normal => BatchPriority::Normal,
GrpcPriority::Background => BatchPriority::Background,
GrpcPriority::High => BatchPriority::High,
GrpcPriority::Critical => BatchPriority::Critical,
}
}
}
impl From<i32> for GrpcPriority {
fn from(v: i32) -> Self {
match v {
1 => Self::Background,
2 => Self::Normal,
3 => Self::High,
4 => Self::Critical,
_ => Self::Unspecified,
}
}
}
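/// Chat message role, covering the usual system/user/assistant/tool split.
///
/// As with `GrpcPriority`, unknown wire values decode to `Unspecified`.
///
/// # Examples
///
/// ```ignore
/// assert_eq!(Role::from(2), Role::User);
/// assert_eq!(Role::Assistant.to_string(), "assistant");
/// ```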
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(i32)]
pub enum Role {
#[default]
Unspecified = 0,
System = 1,
User = 2,
Assistant = 3,
Tool = 4,
}
impl From<i32> for Role {
fn from(v: i32) -> Self {
match v {
1 => Self::System,
2 => Self::User,
3 => Self::Assistant,
4 => Self::Tool,
_ => Self::Unspecified,
}
}
}
impl fmt::Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Unspecified => write!(f, "unspecified"),
Self::System => write!(f, "system"),
Self::User => write!(f, "user"),
Self::Assistant => write!(f, "assistant"),
Self::Tool => write!(f, "tool"),
}
}
}
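/// Token accounting for a single request.
///
/// # Examples
///
/// `new` derives `total_tokens` from the two inputs:
///
/// ```ignore
/// let usage = Usage::new(10, 5);
/// assert_eq!(usage.total_tokens, 15);
/// ```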
#[derive(Debug, Clone, Default)]
pub struct Usage {
pub prompt_tokens: i32,
pub completion_tokens: i32,
pub total_tokens: i32,
}
impl Usage {
pub fn new(prompt: i32, completion: i32) -> Self {
Self {
prompt_tokens: prompt,
completion_tokens: completion,
total_tokens: prompt + completion,
}
}
}
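/// A single chat message. The convenience constructors cover the common
/// roles; `name` and `tool_call_id` stay `None` unless set explicitly.
///
/// # Examples
///
/// ```ignore
/// let msgs = vec![
///     ChatMessage::system("You are helpful."),
///     ChatMessage::user("Hello!"),
/// ];
/// assert_eq!(msgs[1].role, Role::User);
/// ```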
#[derive(Debug, Clone, Default)]
pub struct ChatMessage {
pub role: Role,
pub content: String,
pub name: Option<String>,
pub tool_call_id: Option<String>,
}
impl ChatMessage {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: Role::System,
content: content.into(),
..Default::default()
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: content.into(),
..Default::default()
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: content.into(),
..Default::default()
}
}
}
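/// Parameters for a chat completion call. Optional sampling fields default
/// to `None`, which leaves the choice to the server.
///
/// # Examples
///
/// A typical request built with struct-update syntax:
///
/// ```ignore
/// let request = ChatCompletionRequest {
///     model: "llama-3-8b".to_string(),
///     messages: vec![ChatMessage::user("Hello!")],
///     max_tokens: Some(256),
///     priority: GrpcPriority::High,
///     ..Default::default()
/// };
/// ```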
#[derive(Debug, Clone, Default)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub max_tokens: Option<i32>,
pub stop: Vec<String>,
pub presence_penalty: Option<f32>,
pub frequency_penalty: Option<f32>,
pub logprobs: Option<bool>,
pub top_logprobs: Option<i32>,
pub n: Option<i32>,
pub seed: Option<i64>,
pub priority: GrpcPriority,
pub request_id: Option<String>,
}
#[derive(Debug, Clone, Default)]
pub struct ChatChoice {
pub index: i32,
pub message: ChatMessage,
pub finish_reason: String,
}
#[derive(Debug, Clone, Default)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: i64,
pub model: String,
pub choices: Vec<ChatChoice>,
pub usage: Usage,
}
#[derive(Debug, Clone, Default)]
pub struct ChatCompletionChunk {
pub id: String,
pub object: String,
pub created: i64,
pub model: String,
pub choices: Vec<ChatChoiceDelta>,
}
#[derive(Debug, Clone, Default)]
pub struct ChatChoiceDelta {
pub index: i32,
pub delta: ChatMessageDelta,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Default)]
pub struct ChatMessageDelta {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Debug, Clone, Default)]
pub struct CompletionRequest {
pub model: String,
pub prompt: String,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub max_tokens: Option<i32>,
pub stop: Vec<String>,
pub priority: GrpcPriority,
pub request_id: Option<String>,
}
#[derive(Debug, Clone, Default)]
pub struct CompletionChoice {
pub index: i32,
pub text: String,
pub finish_reason: String,
}
#[derive(Debug, Clone, Default)]
pub struct CompletionResponse {
pub id: String,
pub object: String,
pub created: i64,
pub model: String,
pub choices: Vec<CompletionChoice>,
pub usage: Usage,
}
#[derive(Debug, Clone, Default)]
pub struct CompletionChunk {
pub id: String,
pub object: String,
pub created: i64,
pub model: String,
pub choices: Vec<CompletionChoiceDelta>,
}
#[derive(Debug, Clone, Default)]
pub struct CompletionChoiceDelta {
pub index: i32,
pub text: String,
pub finish_reason: Option<String>,
}
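/// Parameters for an embedding call over one or more inputs.
///
/// # Examples
///
/// ```ignore
/// let request = EmbedRequest {
///     model: "llama-3-8b".to_string(),
///     input: vec!["Hello".to_string(), "World".to_string()],
///     dimensions: Some(768),
///     ..Default::default()
/// };
/// ```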
#[derive(Debug, Clone, Default)]
pub struct EmbedRequest {
pub model: String,
pub input: Vec<String>,
pub encoding_format: Option<String>,
pub dimensions: Option<i32>,
pub priority: GrpcPriority,
pub request_id: Option<String>,
}
#[derive(Debug, Clone, Default)]
pub struct Embedding {
pub object: String,
pub index: i32,
pub embedding: Vec<f32>,
}
#[derive(Debug, Clone, Default)]
pub struct EmbedResponse {
pub object: String,
pub data: Vec<Embedding>,
pub model: String,
pub usage: Usage,
}
#[derive(Debug, Clone, Default)]
pub struct ListModelsRequest {}
#[derive(Debug, Clone, Default)]
pub struct Model {
pub id: String,
pub object: String,
pub created: i64,
pub owned_by: String,
pub context_length: Option<i32>,
}
#[derive(Debug, Clone, Default)]
pub struct ListModelsResponse {
pub object: String,
pub data: Vec<Model>,
}
#[derive(Debug, Clone, Default)]
pub struct HealthCheckRequest {}
#[derive(Debug, Clone, Default)]
pub struct ComponentHealth {
pub name: String,
pub status: String,
pub message: Option<String>,
}
#[derive(Debug, Clone, Default)]
pub struct HealthCheckResponse {
pub status: String,
pub version: String,
pub uptime_seconds: i64,
pub components: Vec<ComponentHealth>,
}
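/// Service-level errors. Each variant maps onto the closest `tonic::Status`
/// code via the `From` impl below, so handlers can return `GrpcError` and
/// let `.into()` produce the wire status.
///
/// # Examples
///
/// ```ignore
/// let status: Status = GrpcError::ModelNotFound("gpt-5".to_string()).into();
/// assert_eq!(status.code(), tonic::Code::NotFound);
/// ```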
#[derive(Debug, Clone)]
pub enum GrpcError {
InvalidRequest(String),
ModelNotFound(String),
Internal(String),
Unavailable(String),
ResourceExhausted(String),
DeadlineExceeded(String),
Unauthenticated(String),
PermissionDenied(String),
}
impl fmt::Display for GrpcError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InvalidRequest(msg) => write!(f, "Invalid request: {}", msg),
Self::ModelNotFound(msg) => write!(f, "Model not found: {}", msg),
Self::Internal(msg) => write!(f, "Internal error: {}", msg),
Self::Unavailable(msg) => write!(f, "Unavailable: {}", msg),
Self::ResourceExhausted(msg) => write!(f, "Resource exhausted: {}", msg),
Self::DeadlineExceeded(msg) => write!(f, "Deadline exceeded: {}", msg),
Self::Unauthenticated(msg) => write!(f, "Unauthenticated: {}", msg),
Self::PermissionDenied(msg) => write!(f, "Permission denied: {}", msg),
}
}
}
impl std::error::Error for GrpcError {}
impl From<GrpcError> for Status {
fn from(err: GrpcError) -> Self {
match err {
GrpcError::InvalidRequest(msg) => Status::invalid_argument(msg),
GrpcError::ModelNotFound(msg) => Status::not_found(msg),
GrpcError::Internal(msg) => Status::internal(msg),
GrpcError::Unavailable(msg) => Status::unavailable(msg),
GrpcError::ResourceExhausted(msg) => Status::resource_exhausted(msg),
GrpcError::DeadlineExceeded(msg) => Status::deadline_exceeded(msg),
GrpcError::Unauthenticated(msg) => Status::unauthenticated(msg),
GrpcError::PermissionDenied(msg) => Status::permission_denied(msg),
}
}
}
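/// Lock-free request counters for the gRPC layer.
///
/// All counters use `Ordering::Relaxed`: each value is independently
/// monotonic and nothing synchronizes through them, so relaxed atomics
/// are sufficient.
///
/// # Examples
///
/// ```ignore
/// let metrics = GrpcMetrics::new();
/// metrics.record_request();
/// metrics.record_success(Duration::from_millis(100));
/// assert_eq!(metrics.requests_total(), 1);
/// assert_eq!(metrics.avg_response_time(), Duration::from_millis(100));
/// ```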
#[derive(Debug)]
pub struct GrpcMetrics {
requests_total: AtomicU64,
requests_success: AtomicU64,
requests_failed: AtomicU64,
active_streams: AtomicU64,
response_time_ns: AtomicU64,
started_at: Instant,
}
impl GrpcMetrics {
pub fn new() -> Self {
Self {
requests_total: AtomicU64::new(0),
requests_success: AtomicU64::new(0),
requests_failed: AtomicU64::new(0),
active_streams: AtomicU64::new(0),
response_time_ns: AtomicU64::new(0),
started_at: Instant::now(),
}
}
pub fn record_request(&self) {
self.requests_total.fetch_add(1, Ordering::Relaxed);
}
pub fn record_success(&self, duration: Duration) {
self.requests_success.fetch_add(1, Ordering::Relaxed);
self.response_time_ns
.fetch_add(duration.as_nanos() as u64, Ordering::Relaxed);
}
pub fn record_failure(&self) {
self.requests_failed.fetch_add(1, Ordering::Relaxed);
}
pub fn stream_start(&self) {
self.active_streams.fetch_add(1, Ordering::Relaxed);
}
pub fn stream_end(&self) {
self.active_streams.fetch_sub(1, Ordering::Relaxed);
}
pub fn requests_total(&self) -> u64 {
self.requests_total.load(Ordering::Relaxed)
}
pub fn requests_success(&self) -> u64 {
self.requests_success.load(Ordering::Relaxed)
}
pub fn requests_failed(&self) -> u64 {
self.requests_failed.load(Ordering::Relaxed)
}
pub fn active_streams(&self) -> u64 {
self.active_streams.load(Ordering::Relaxed)
}
pub fn uptime(&self) -> Duration {
self.started_at.elapsed()
}
pub fn avg_response_time(&self) -> Duration {
let success = self.requests_success();
let total_ns = self.response_time_ns.load(Ordering::Relaxed);
if success > 0 {
Duration::from_nanos(total_ns / success)
} else {
Duration::ZERO
}
}
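/// Renders the counters in the Prometheus text exposition format
/// (`# HELP` / `# TYPE` pairs followed by one sample line per metric).
///
/// Sample output, assuming a single recorded request:
///
/// ```text
/// # HELP infernum_grpc_requests_total Total gRPC requests
/// # TYPE infernum_grpc_requests_total counter
/// infernum_grpc_requests_total 1
/// ```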
pub fn prometheus(&self) -> String {
let mut output = String::new();
output.push_str("# HELP infernum_grpc_requests_total Total gRPC requests\n");
output.push_str("# TYPE infernum_grpc_requests_total counter\n");
output.push_str(&format!(
"infernum_grpc_requests_total {}\n",
self.requests_total()
));
output.push_str("# HELP infernum_grpc_requests_success Successful gRPC requests\n");
output.push_str("# TYPE infernum_grpc_requests_success counter\n");
output.push_str(&format!(
"infernum_grpc_requests_success {}\n",
self.requests_success()
));
output.push_str("# HELP infernum_grpc_requests_failed Failed gRPC requests\n");
output.push_str("# TYPE infernum_grpc_requests_failed counter\n");
output.push_str(&format!(
"infernum_grpc_requests_failed {}\n",
self.requests_failed()
));
output.push_str("# HELP infernum_grpc_active_streams Active gRPC streams\n");
output.push_str("# TYPE infernum_grpc_active_streams gauge\n");
output.push_str(&format!(
"infernum_grpc_active_streams {}\n",
self.active_streams()
));
output.push_str("# HELP infernum_grpc_avg_response_seconds Average response time\n");
output.push_str("# TYPE infernum_grpc_avg_response_seconds gauge\n");
output.push_str(&format!(
"infernum_grpc_avg_response_seconds {:.6}\n",
self.avg_response_time().as_secs_f64()
));
output
}
}
impl Default for GrpcMetrics {
fn default() -> Self {
Self::new()
}
}
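/// The core inference API. Implementors provide unary and server-streaming
/// variants of chat and text completion, plus embeddings, model listing,
/// and health checks.
///
/// Streaming methods return a boxed `Stream` of chunks so callers can
/// forward them directly as a tonic response body.
///
/// # Examples
///
/// A sketch of driving an implementation (here the mock) from async code:
///
/// ```ignore
/// use tokio_stream::StreamExt;
///
/// let service = MockInfernumService::new();
/// let request = ChatCompletionRequest {
///     model: "llama-3-8b".to_string(),
///     messages: vec![ChatMessage::user("Hi")],
///     ..Default::default()
/// };
/// let mut stream = service.chat_completion_stream(request).await?;
/// while let Some(chunk) = stream.next().await {
///     let chunk = chunk?;
///     // Deltas arrive incrementally; the final chunk carries finish_reason.
/// }
/// ```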
#[tonic::async_trait]
pub trait InfernumService: Send + Sync + 'static {
async fn chat_completion(
&self,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, GrpcError>;
async fn chat_completion_stream(
&self,
request: ChatCompletionRequest,
) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk, Status>> + Send>>, GrpcError>;
async fn completion(&self, request: CompletionRequest)
-> Result<CompletionResponse, GrpcError>;
async fn completion_stream(
&self,
request: CompletionRequest,
) -> Result<Pin<Box<dyn Stream<Item = Result<CompletionChunk, Status>> + Send>>, GrpcError>;
async fn embed(&self, request: EmbedRequest) -> Result<EmbedResponse, GrpcError>;
async fn list_models(&self) -> Result<ListModelsResponse, GrpcError>;
async fn health_check(&self) -> Result<HealthCheckResponse, GrpcError>;
}
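/// An in-memory `InfernumService` that returns canned responses, suitable
/// for exercising transport and client code before a real backend exists.
///
/// # Examples
///
/// ```ignore
/// let service = MockInfernumService::new();
/// let health = service.health_check().await?;
/// assert_eq!(health.status, "healthy");
/// ```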
#[derive(Debug, Default)]
pub struct MockInfernumService {
pub metrics: Arc<GrpcMetrics>,
}
impl MockInfernumService {
pub fn new() -> Self {
Self {
metrics: Arc::new(GrpcMetrics::new()),
}
}
}
#[tonic::async_trait]
impl InfernumService for MockInfernumService {
async fn chat_completion(
&self,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, GrpcError> {
let start = Instant::now();
self.metrics.record_request();
let response = ChatCompletionResponse {
id: format!("inf-chat-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp(),
model: request.model.clone(),
choices: vec![ChatChoice {
index: 0,
message: ChatMessage::assistant("This is a mock response."),
finish_reason: "stop".to_string(),
}],
usage: Usage::new(10, 5),
};
self.metrics.record_success(start.elapsed());
Ok(response)
}
async fn chat_completion_stream(
&self,
request: ChatCompletionRequest,
) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk, Status>> + Send>>, GrpcError>
{
self.metrics.record_request();
self.metrics.stream_start();
let (tx, rx) = mpsc::channel(10);
let model = request.model.clone();
let id = format!("inf-chat-{}", uuid::Uuid::new_v4());
let created = chrono::Utc::now().timestamp();
// Clone the metrics handle so the spawned task can close out the
// active-stream gauge when the mock stream finishes.
let metrics = Arc::clone(&self.metrics);
tokio::spawn(async move {
let chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
choices: vec![ChatChoiceDelta {
index: 0,
delta: ChatMessageDelta {
role: Some(Role::Assistant),
content: None,
},
finish_reason: None,
}],
};
let _ = tx.send(Ok(chunk)).await;
for word in ["This", " is", " a", " mock", " streaming", " response", "."] {
let chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
choices: vec![ChatChoiceDelta {
index: 0,
delta: ChatMessageDelta {
role: None,
content: Some(word.to_string()),
},
finish_reason: None,
}],
};
let _ = tx.send(Ok(chunk)).await;
tokio::time::sleep(Duration::from_millis(50)).await;
}
let chunk = ChatCompletionChunk {
id,
object: "chat.completion.chunk".to_string(),
created,
model,
choices: vec![ChatChoiceDelta {
index: 0,
delta: ChatMessageDelta::default(),
finish_reason: Some("stop".to_string()),
}],
};
let _ = tx.send(Ok(chunk)).await;
// Balance the stream_start() call above once the stream completes.
metrics.stream_end();
});
Ok(Box::pin(ReceiverStream::new(rx)))
}
async fn completion(
&self,
request: CompletionRequest,
) -> Result<CompletionResponse, GrpcError> {
let start = Instant::now();
self.metrics.record_request();
let response = CompletionResponse {
id: format!("inf-cmpl-{}", uuid::Uuid::new_v4()),
object: "text_completion".to_string(),
created: chrono::Utc::now().timestamp(),
model: request.model.clone(),
choices: vec![CompletionChoice {
index: 0,
text: "This is a mock completion.".to_string(),
finish_reason: "stop".to_string(),
}],
usage: Usage::new(10, 5),
};
self.metrics.record_success(start.elapsed());
Ok(response)
}
async fn completion_stream(
&self,
request: CompletionRequest,
) -> Result<Pin<Box<dyn Stream<Item = Result<CompletionChunk, Status>> + Send>>, GrpcError>
{
self.metrics.record_request();
self.metrics.stream_start();
let (tx, rx) = mpsc::channel(10);
let model = request.model.clone();
let id = format!("inf-cmpl-{}", uuid::Uuid::new_v4());
let created = chrono::Utc::now().timestamp();
// As above, clone the metrics handle so the task can balance stream_start().
let metrics = Arc::clone(&self.metrics);
tokio::spawn(async move {
for word in ["This", " is", " a", " mock", " completion", "."] {
let chunk = CompletionChunk {
id: id.clone(),
object: "text_completion".to_string(),
created,
model: model.clone(),
choices: vec![CompletionChoiceDelta {
index: 0,
text: word.to_string(),
finish_reason: None,
}],
};
let _ = tx.send(Ok(chunk)).await;
tokio::time::sleep(Duration::from_millis(50)).await;
}
let chunk = CompletionChunk {
id,
object: "text_completion".to_string(),
created,
model,
choices: vec![CompletionChoiceDelta {
index: 0,
text: String::new(),
finish_reason: Some("stop".to_string()),
}],
};
let _ = tx.send(Ok(chunk)).await;
metrics.stream_end();
});
Ok(Box::pin(ReceiverStream::new(rx)))
}
async fn embed(&self, request: EmbedRequest) -> Result<EmbedResponse, GrpcError> {
let start = Instant::now();
self.metrics.record_request();
// Clamp to non-negative before the usize cast so a bad request cannot
// sign-extend into a huge allocation.
let dims = request.dimensions.unwrap_or(1536).max(0) as usize;
let embeddings: Vec<Embedding> = request
.input
.iter()
.enumerate()
.map(|(i, _)| Embedding {
object: "embedding".to_string(),
index: i as i32,
embedding: vec![0.0; dims],
})
.collect();
let response = EmbedResponse {
object: "list".to_string(),
data: embeddings,
model: request.model.clone(),
usage: Usage::new(request.input.len() as i32 * 5, 0),
};
self.metrics.record_success(start.elapsed());
Ok(response)
}
async fn list_models(&self) -> Result<ListModelsResponse, GrpcError> {
self.metrics.record_request();
Ok(ListModelsResponse {
object: "list".to_string(),
data: vec![
Model {
id: "llama-3-8b".to_string(),
object: "model".to_string(),
created: 1700000000,
owned_by: "meta".to_string(),
context_length: Some(8192),
},
Model {
id: "qwen-2.5-32b".to_string(),
object: "model".to_string(),
created: 1700000000,
owned_by: "alibaba".to_string(),
context_length: Some(32768),
},
],
})
}
async fn health_check(&self) -> Result<HealthCheckResponse, GrpcError> {
Ok(HealthCheckResponse {
status: "healthy".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
uptime_seconds: self.metrics.uptime().as_secs() as i64,
components: vec![
ComponentHealth {
name: "inference".to_string(),
status: "healthy".to_string(),
message: None,
},
ComponentHealth {
name: "models".to_string(),
status: "healthy".to_string(),
message: Some("2 models loaded".to_string()),
},
],
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_grpc_config_default() {
let config = GrpcConfig::default();
assert_eq!(config.max_message_size, 16 * 1024 * 1024);
assert!(config.enable_reflection);
assert!(config.enable_health);
}
#[test]
fn test_grpc_config_builder() {
let config = GrpcConfig::new()
.with_max_message_size(8 * 1024 * 1024)
.with_reflection(false)
.with_max_concurrent_streams(100);
assert_eq!(config.max_message_size, 8 * 1024 * 1024);
assert!(!config.enable_reflection);
assert_eq!(config.max_concurrent_streams, 100);
}
#[test]
fn test_grpc_priority_conversion() {
assert_eq!(GrpcPriority::from(0), GrpcPriority::Unspecified);
assert_eq!(GrpcPriority::from(1), GrpcPriority::Background);
assert_eq!(GrpcPriority::from(2), GrpcPriority::Normal);
assert_eq!(GrpcPriority::from(3), GrpcPriority::High);
assert_eq!(GrpcPriority::from(4), GrpcPriority::Critical);
}
#[test]
fn test_grpc_priority_to_batch_priority() {
assert_eq!(
BatchPriority::from(GrpcPriority::Background),
BatchPriority::Background
);
assert_eq!(
BatchPriority::from(GrpcPriority::Normal),
BatchPriority::Normal
);
assert_eq!(BatchPriority::from(GrpcPriority::High), BatchPriority::High);
assert_eq!(
BatchPriority::from(GrpcPriority::Critical),
BatchPriority::Critical
);
}
#[test]
fn test_role_display() {
assert_eq!(Role::System.to_string(), "system");
assert_eq!(Role::User.to_string(), "user");
assert_eq!(Role::Assistant.to_string(), "assistant");
assert_eq!(Role::Tool.to_string(), "tool");
}
#[test]
fn test_usage_new() {
let usage = Usage::new(10, 5);
assert_eq!(usage.prompt_tokens, 10);
assert_eq!(usage.completion_tokens, 5);
assert_eq!(usage.total_tokens, 15);
}
#[test]
fn test_chat_message_constructors() {
let system = ChatMessage::system("You are helpful.");
assert_eq!(system.role, Role::System);
assert_eq!(system.content, "You are helpful.");
let user = ChatMessage::user("Hello!");
assert_eq!(user.role, Role::User);
let assistant = ChatMessage::assistant("Hi there!");
assert_eq!(assistant.role, Role::Assistant);
}
#[test]
fn test_grpc_error_display() {
let err = GrpcError::InvalidRequest("bad input".to_string());
assert!(err.to_string().contains("Invalid request"));
let err = GrpcError::ModelNotFound("gpt-5".to_string());
assert!(err.to_string().contains("Model not found"));
let err = GrpcError::Internal("oops".to_string());
assert!(err.to_string().contains("Internal"));
}
#[test]
fn test_grpc_error_to_status() {
let err = GrpcError::InvalidRequest("test".to_string());
let status: Status = err.into();
assert_eq!(status.code(), tonic::Code::InvalidArgument);
let err = GrpcError::ModelNotFound("test".to_string());
let status: Status = err.into();
assert_eq!(status.code(), tonic::Code::NotFound);
let err = GrpcError::Unauthenticated("test".to_string());
let status: Status = err.into();
assert_eq!(status.code(), tonic::Code::Unauthenticated);
}
#[test]
fn test_grpc_metrics_new() {
let metrics = GrpcMetrics::new();
assert_eq!(metrics.requests_total(), 0);
assert_eq!(metrics.requests_success(), 0);
assert_eq!(metrics.requests_failed(), 0);
assert_eq!(metrics.active_streams(), 0);
}
#[test]
fn test_grpc_metrics_record() {
let metrics = GrpcMetrics::new();
metrics.record_request();
metrics.record_request();
metrics.record_success(Duration::from_millis(100));
metrics.record_failure();
assert_eq!(metrics.requests_total(), 2);
assert_eq!(metrics.requests_success(), 1);
assert_eq!(metrics.requests_failed(), 1);
}
#[test]
fn test_grpc_metrics_streams() {
let metrics = GrpcMetrics::new();
metrics.stream_start();
metrics.stream_start();
assert_eq!(metrics.active_streams(), 2);
metrics.stream_end();
assert_eq!(metrics.active_streams(), 1);
}
#[test]
fn test_grpc_metrics_prometheus() {
let metrics = GrpcMetrics::new();
metrics.record_request();
metrics.record_success(Duration::from_millis(10));
let output = metrics.prometheus();
assert!(output.contains("infernum_grpc_requests_total 1"));
assert!(output.contains("infernum_grpc_requests_success 1"));
assert!(output.contains("infernum_grpc_active_streams"));
}
#[tokio::test]
async fn test_mock_service_chat_completion() {
let service = MockInfernumService::new();
let request = ChatCompletionRequest {
model: "llama-3-8b".to_string(),
messages: vec![ChatMessage::user("Hello!")],
..Default::default()
};
let response = service.chat_completion(request).await.unwrap();
assert!(!response.id.is_empty());
assert_eq!(response.object, "chat.completion");
assert_eq!(response.choices.len(), 1);
assert_eq!(response.choices[0].finish_reason, "stop");
}
#[tokio::test]
async fn test_mock_service_completion() {
let service = MockInfernumService::new();
let request = CompletionRequest {
model: "llama-3-8b".to_string(),
prompt: "Hello".to_string(),
..Default::default()
};
let response = service.completion(request).await.unwrap();
assert!(!response.id.is_empty());
assert_eq!(response.object, "text_completion");
assert_eq!(response.choices.len(), 1);
}
#[tokio::test]
async fn test_mock_service_embed() {
let service = MockInfernumService::new();
let request = EmbedRequest {
model: "llama-3-8b".to_string(),
input: vec!["Hello".to_string(), "World".to_string()],
dimensions: Some(768),
..Default::default()
};
let response = service.embed(request).await.unwrap();
assert_eq!(response.data.len(), 2);
assert_eq!(response.data[0].embedding.len(), 768);
}
#[tokio::test]
async fn test_mock_service_list_models() {
let service = MockInfernumService::new();
let response = service.list_models().await.unwrap();
assert_eq!(response.object, "list");
assert!(!response.data.is_empty());
}
#[tokio::test]
async fn test_mock_service_health_check() {
let service = MockInfernumService::new();
let response = service.health_check().await.unwrap();
assert_eq!(response.status, "healthy");
assert!(!response.version.is_empty());
assert!(!response.components.is_empty());
}
}