nabla_cli/enterprise/providers/
types.rs1use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::sync::Arc;
5use super::http::HTTPProvider;
6
7pub struct InferenceManager {
9 pub default_provider: Arc<dyn InferenceProvider>,
10}
11
12impl InferenceManager {
13 pub fn new() -> Self {
14 let default_provider = Arc::new(HTTPProvider::new(
16 "http://localhost:11434".to_string(),
17 None,
18 None,
19 ));
20
21 Self {
22 default_provider,
23 }
24 }
25
26 pub fn get_default_provider(&self) -> Arc<dyn InferenceProvider> {
28 self.default_provider.clone()
29 }
30
31 pub fn create_http_provider(
33 &self,
34 inference_url: String,
35 api_key: Option<String>,
36 provider_token: Option<String>,
37 ) -> Arc<dyn InferenceProvider> {
38 Arc::new(HTTPProvider::new(inference_url, api_key, provider_token))
39 }
40}
41
42#[async_trait]
43pub trait InferenceProvider: Send + Sync {
44 async fn generate(&self, prompt: &str, options: &GenerationOptions) -> Result<GenerationResponse, InferenceError>;
45 #[allow(dead_code)]
46 async fn embed(&self, _text: &str) -> Result<Vec<f32>, InferenceError> {
47 Err(InferenceError::ServerError("Embedding not supported".to_string()))
49 }
50 #[allow(dead_code)]
51 async fn is_available(&self) -> bool; }
53
54#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
55pub struct GenerationOptions {
56 pub max_tokens: usize,
57 pub temperature: f32,
58 pub top_p: f32,
59 pub stop_sequences: Vec<String>,
60 pub model_path: Option<String>, pub hf_repo: Option<String>, pub model: Option<String>, }
64
65impl Default for GenerationOptions {
66 fn default() -> Self {
67 Self {
68 max_tokens: 512,
69 temperature: 0.7,
70 top_p: 0.9,
71 stop_sequences: vec![],
72 model_path: None,
73 hf_repo: None,
74 model: None,
75 }
76 }
77}
78
79#[derive(Debug, Serialize, Deserialize)]
80pub struct GenerationResponse {
81 pub text: String,
82 pub tokens_used: usize,
83 pub finish_reason: String,
84}
85
86#[derive(Debug, thiserror::Error)]
87pub enum InferenceError {
88 #[allow(dead_code)]
89 #[error("No available inference provider")]
90 NoAvailableProvider,
91 #[error("Server error: {0}")]
92 ServerError(String),
93 #[error("Network error: {0}")]
94 NetworkError(String),
95}