nabla_cli/enterprise/providers/types.rs

// src/enterprise/providers/types.rs
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use super::http::HTTPProvider;

/// Centralized manager for inference providers
pub struct InferenceManager {
    pub default_provider: Arc<dyn InferenceProvider>,
}

impl InferenceManager {
    pub fn new() -> Self {
        // Create a default HTTP provider
        let default_provider = Arc::new(HTTPProvider::new(
            "http://localhost:11434".to_string(),
            None,
            None,
        ));

        Self { default_provider }
    }

    /// Get the default inference provider
    pub fn get_default_provider(&self) -> Arc<dyn InferenceProvider> {
        self.default_provider.clone()
    }

    /// Create a new HTTP provider with custom configuration
    pub fn create_http_provider(
        &self,
        inference_url: String,
        api_key: Option<String>,
        provider_token: Option<String>,
    ) -> Arc<dyn InferenceProvider> {
        Arc::new(HTTPProvider::new(inference_url, api_key, provider_token))
    }
}
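
/// Abstraction over an inference backend (text generation plus optional
/// embeddings and availability checks).
///
/// Illustrative sketch of a minimal implementation; `EchoProvider` is a
/// hypothetical type used only in this example and is not part of the crate.
/// Only `generate` and `is_available` must be provided; `embed` has a default.
///
/// ```ignore
/// struct EchoProvider;
///
/// #[async_trait]
/// impl InferenceProvider for EchoProvider {
///     async fn generate(
///         &self,
///         prompt: &str,
///         _options: &GenerationOptions,
///     ) -> Result<GenerationResponse, InferenceError> {
///         Ok(GenerationResponse {
///             text: prompt.to_string(),
///             tokens_used: 0,
///             finish_reason: "stop".to_string(),
///         })
///     }
///
///     async fn is_available(&self) -> bool {
///         true
///     }
/// }
/// ```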
#[async_trait]
pub trait InferenceProvider: Send + Sync {
    async fn generate(&self, prompt: &str, options: &GenerationOptions) -> Result<GenerationResponse, InferenceError>;

    #[allow(dead_code)]
    async fn embed(&self, _text: &str) -> Result<Vec<f32>, InferenceError> {
        // Default implementation: not every provider supports embeddings
        Err(InferenceError::ServerError("Embedding not supported".to_string()))
    }

    #[allow(dead_code)]
    async fn is_available(&self) -> bool;
}
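
/// Options controlling a single generation request.
///
/// Illustrative sketch: callers typically override a few fields via struct
/// update syntax. The values and the model name below are arbitrary examples,
/// not project defaults.
///
/// ```ignore
/// let opts = GenerationOptions {
///     max_tokens: 1024,
///     temperature: 0.2,
///     model: Some("example-model".to_string()),
///     ..Default::default()
/// };
/// ```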
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationOptions {
    pub max_tokens: usize,
    pub temperature: f32,
    pub top_p: f32,
    pub stop_sequences: Vec<String>,
    pub model_path: Option<String>, // For local GGUF files
    pub hf_repo: Option<String>,    // For remote HF repos
    pub model: Option<String>,      // For OpenAI-compatible APIs (e.g., Together, OpenAI)
}

impl Default for GenerationOptions {
    fn default() -> Self {
        Self {
            max_tokens: 512,
            temperature: 0.7,
            top_p: 0.9,
            stop_sequences: vec![],
            model_path: None,
            hf_repo: None,
            model: None,
        }
    }
}

#[derive(Debug, Serialize, Deserialize)]
pub struct GenerationResponse {
    pub text: String,
    pub tokens_used: usize,
    pub finish_reason: String,
}

#[derive(Debug, thiserror::Error)]
pub enum InferenceError {
    #[allow(dead_code)]
    #[error("No available inference provider")]
    NoAvailableProvider,
    #[error("Server error: {0}")]
    ServerError(String),
    #[error("Network error: {0}")]
    NetworkError(String),
}
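
// Illustrative usage sketch: exercises the option defaults and the manager's
// provider constructors without making any network calls. The endpoint and
// API key below are placeholders, not real configuration values.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_options_are_sane() {
        let opts = GenerationOptions::default();
        assert_eq!(opts.max_tokens, 512);
        assert!(opts.model_path.is_none());
        assert!(opts.hf_repo.is_none());
        assert!(opts.model.is_none());
    }

    #[test]
    fn manager_builds_custom_provider() {
        let manager = InferenceManager::new();
        // Placeholder endpoint and key; real values come from caller configuration.
        let _provider = manager.create_http_provider(
            "https://inference.example.com/v1".to_string(),
            Some("placeholder-api-key".to_string()),
            None,
        );
    }
}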