use crate::provider::{
CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
Role, StreamChunk, Usage,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use futures::stream::BoxStream;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Provider that runs model inference locally via candle, preferring a CUDA
/// device and falling back to CPU when no GPU is available.
pub struct LocalCudaProvider {
    /// Identifier reported by `list_models` (e.g. "qwen2.5-coder-7b").
    model_name: String,
    /// Device selected at construction time (CUDA 0, or CPU fallback).
    device: Device,
    /// Lazily-populated model state, shared across async tasks.
    /// `None` until a model path is supplied or a model is loaded.
    model_cache: Arc<Mutex<Option<ModelCache>>>,
}
/// Cached model metadata. Currently only the on-disk path; weights/tokenizer
/// state would live here once inference is implemented.
struct ModelCache {
    model_path: String,
}
impl LocalCudaProvider {
    /// Create a provider for `model_name`, probing CUDA device 0 and
    /// silently falling back to the CPU when no GPU is usable.
    pub fn new(model_name: String) -> Result<Self> {
        let device = if let Ok(cuda) = Device::new_cuda(0) {
            tracing::info!("Using CUDA device for local inference");
            cuda
        } else {
            tracing::warn!("CUDA not available, using CPU (will be slow)");
            Device::Cpu
        };
        Ok(Self {
            model_name,
            device,
            model_cache: Arc::new(Mutex::new(None)),
        })
    }

    /// Like [`Self::new`], but seeds the cache with a known model path so the
    /// first inference call can skip model discovery.
    pub fn with_model(model_name: String, model_path: String) -> Result<Self> {
        let mut provider = Self::new(model_name)?;
        provider.model_cache = Arc::new(Mutex::new(Some(ModelCache { model_path })));
        Ok(provider)
    }

    /// True when CUDA device 0 can be initialized on this host.
    pub fn is_cuda_available() -> bool {
        Device::new_cuda(0).is_ok()
    }

    /// Human-readable description of the device inference would run on.
    pub fn device_info() -> String {
        Device::new_cuda(0)
            .map(|d| format!("CUDA: {}", d))
            .unwrap_or_else(|_| "CPU only".to_string())
    }
}
#[async_trait]
impl Provider for LocalCudaProvider {
    fn name(&self) -> &str {
        "local_cuda"
    }

    /// Advertise the single locally-configured model.
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        Ok(vec![ModelInfo {
            id: self.model_name.clone(),
            name: self.model_name.clone(),
            provider: "local_cuda".to_string(),
            context_window: 8192,
            max_output_tokens: Some(4096),
            supports_vision: false,
            supports_tools: true,
            supports_streaming: true,
            // Local inference incurs no per-token API cost.
            input_cost_per_million: Some(0.0),
            output_cost_per_million: Some(0.0),
        }])
    }

    /// Format the conversation into a prompt. Actual inference is not yet
    /// implemented, so this always returns an error carrying a prompt preview.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let prompt = Self::format_messages(&request.messages);
        tracing::debug!(
            model = %self.model_name,
            prompt_len = prompt.len(),
            "Local CUDA inference request"
        );
        // BUG FIX: the previous `&prompt[..prompt.len().min(100)]` sliced at a
        // byte index and panicked whenever byte 100 landed inside a multi-byte
        // UTF-8 character. Truncate by characters instead, which is always safe.
        let preview: String = prompt.chars().take(100).collect();
        Err(anyhow!(
            "Local CUDA inference requires model implementation. \
            Prompt would be: {}... (truncated)",
            preview
        ))
    }

    async fn complete_stream(
        &self,
        // Unused until streaming is implemented; underscore avoids the warning.
        _request: CompletionRequest,
    ) -> Result<BoxStream<'static, StreamChunk>> {
        Err(anyhow!(
            "Streaming inference not yet implemented for local_cuda provider"
        ))
    }
}
impl LocalCudaProvider {
    /// Render a conversation as a flat "Role: text" transcript, ending with an
    /// open "Assistant: " turn for the model to continue.
    fn format_messages(messages: &[Message]) -> String {
        let mut prompt = String::new();
        for msg in messages {
            // All roles share the same layout; only the label differs.
            let label = match msg.role {
                Role::System => "System: ",
                Role::User => "User: ",
                Role::Assistant => "Assistant: ",
                Role::Tool => "Tool: ",
            };
            prompt.push_str(label);
            prompt.push_str(&Self::content_to_string(&msg.content));
            prompt.push_str("\n\n");
        }
        prompt.push_str("Assistant: ");
        prompt
    }

    /// Flatten content parts into newline-joined plain text. Parts with no
    /// textual payload contribute an empty string.
    fn content_to_string(parts: &[ContentPart]) -> String {
        let mut pieces = Vec::with_capacity(parts.len());
        for part in parts {
            let text = match part {
                ContentPart::Text { text } => text.clone(),
                ContentPart::ToolResult { content, .. } => content.clone(),
                ContentPart::Thinking { text } => text.clone(),
                _ => String::new(),
            };
            pieces.push(text);
        }
        pieces.join("\n")
    }
}
/// User-facing configuration for the local CUDA provider.
/// `None` fields fall back to the values in [`Default`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LocalCudaConfig {
    /// Model identifier to advertise and load.
    pub model_name: String,
    /// Optional explicit path to model weights on disk.
    pub model_path: Option<String>,
    /// Maximum prompt + completion length in tokens.
    pub context_window: Option<usize>,
    /// Cap on generated tokens per completion.
    pub max_new_tokens: Option<usize>,
    /// Sampling temperature (higher = more random).
    pub temperature: Option<f32>,
    /// Nucleus-sampling cutoff.
    pub top_p: Option<f32>,
    /// Penalty applied to recently generated tokens to reduce repetition.
    pub repeat_penalty: Option<f32>,
    /// CUDA device ordinal to use.
    pub cuda_device: Option<usize>,
}
impl Default for LocalCudaConfig {
    /// Conservative defaults matching the 8192-token context advertised by
    /// `list_models`, with common sampling settings for code models.
    fn default() -> Self {
        Self {
            model_name: "qwen2.5-coder-7b".to_string(),
            model_path: None, // resolved at load time when unset
            context_window: Some(8192),
            max_new_tokens: Some(4096),
            temperature: Some(0.7),
            top_p: Some(0.9),
            repeat_penalty: Some(1.1),
            cuda_device: Some(0), // first GPU
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_local_cuda_provider_creation() {
        let provider =
            LocalCudaProvider::new("test-model".to_string()).expect("construction should succeed");
        assert_eq!(provider.name(), "local_cuda");
    }

    #[test]
    fn test_cuda_availability_check() {
        // Just exercise the probe; the result depends on host hardware.
        let _ = LocalCudaProvider::is_cuda_available();
    }

    #[test]
    fn test_format_messages() {
        let system = Message {
            role: Role::System,
            content: vec![ContentPart::Text {
                text: "You are a helpful assistant.".to_string(),
            }],
        };
        let user = Message {
            role: Role::User,
            content: vec![ContentPart::Text {
                text: "Hello!".to_string(),
            }],
        };
        let formatted = LocalCudaProvider::format_messages(&[system, user]);
        assert!(formatted.contains("You are a helpful assistant."));
        assert!(formatted.contains("Hello!"));
    }
}