use crate::config::{SandboxProfile, ModelResourceRequirements};
use async_trait::async_trait;
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;
use thiserror::Error;
use tokio::process::Command;
use tokio::time::timeout;
/// Errors surfaced while constructing or driving a local SLM runner.
///
/// Variants map to distinct failure stages: setup (`InitializationFailed`,
/// `ModelFileNotFound`), execution (`ExecutionFailed`, `ExecutionTimeout`),
/// and policy (`ResourceLimitExceeded`, `SandboxViolation`, `InvalidInput`).
#[derive(Debug, Error)]
pub enum SlmRunnerError {
/// Runner setup failed (e.g. no llama.cpp binary, or health check failed).
#[error("Model initialization failed: {reason}")]
InitializationFailed { reason: String },
/// The spawned model process failed to run or exited non-zero.
#[error("Model execution failed: {reason}")]
ExecutionFailed { reason: String },
/// A configured resource ceiling was exceeded; `limit_type` names which.
#[error("Resource limit exceeded: {limit_type}")]
ResourceLimitExceeded { limit_type: String },
/// The sandbox profile rejected the requested execution.
#[error("Sandbox violation: {violation}")]
SandboxViolation { violation: String },
/// The GGUF model file does not exist at `path`.
#[error("Model file not found: {path}")]
ModelFileNotFound { path: String },
/// Execution exceeded the wall-clock timeout.
#[error("Execution timeout after {seconds} seconds")]
ExecutionTimeout { seconds: u64 },
/// Caller-supplied input failed validation (e.g. prompt too long).
#[error("Invalid input: {reason}")]
InvalidInput { reason: String },
/// Wrapper for underlying I/O failures.
#[error("IO error: {message}")]
IoError { message: String },
}
/// Per-call tuning knobs for a single model execution.
#[derive(Debug, Clone)]
pub struct ExecutionOptions {
/// Wall-clock limit for the whole generation; `None` falls back to the
/// sandbox profile's `max_execution_time_seconds`.
pub timeout: Option<Duration>,
/// Sampling temperature passed to llama.cpp via `--temp` when set.
pub temperature: Option<f32>,
/// Maximum tokens to generate, passed via `--n-predict` when set.
pub max_tokens: Option<u32>,
/// Extra CLI flags: each `(key, value)` becomes `--{key} {value}`.
pub custom_parameters: HashMap<String, String>,
}
impl Default for ExecutionOptions {
fn default() -> Self {
Self {
timeout: Some(Duration::from_secs(30)),
temperature: Some(0.7),
max_tokens: Some(256),
custom_parameters: HashMap::new(),
}
}
}
/// Successful outcome of one model execution.
#[derive(Debug, Clone)]
pub struct ExecutionResult {
/// Model output (stdout of the llama.cpp process, trimmed).
pub response: String,
/// Timing and usage statistics for this execution.
pub metadata: ExecutionMetadata,
}
/// Statistics collected for a single execution.
#[derive(Debug, Clone)]
pub struct ExecutionMetadata {
/// Estimated prompt token count (byte-length heuristic, not tokenizer-accurate).
pub input_tokens: Option<u32>,
/// Estimated response token count (byte-length heuristic).
pub output_tokens: Option<u32>,
/// Total wall-clock time of the execution in milliseconds.
pub execution_time_ms: u64,
/// Peak memory use in MB, when the backend can report it (currently unset
/// by `LocalGgufRunner`).
pub memory_usage_mb: Option<u64>,
/// Names of any resource limits that were hit during execution.
pub limits_hit: Vec<String>,
}
/// Common interface for small-language-model execution backends.
#[async_trait]
pub trait SlmRunner: Send + Sync {
/// Runs one generation for `prompt`. Passing `None` for `options` uses
/// `ExecutionOptions::default()`.
async fn execute(
&self,
prompt: &str,
options: Option<ExecutionOptions>,
) -> Result<ExecutionResult, SlmRunnerError>;
/// Sandbox profile this runner applies to its executions.
fn get_sandbox_profile(&self) -> &SandboxProfile;
/// Declared resource needs of the underlying model.
fn get_resource_requirements(&self) -> &ModelResourceRequirements;
/// Liveness probe; implementations may run a minimal generation.
async fn health_check(&self) -> Result<(), SlmRunnerError>;
/// Static descriptive metadata about this runner.
fn get_info(&self) -> RunnerInfo;
}
/// Descriptive metadata returned by `SlmRunner::get_info`.
#[derive(Debug, Clone)]
pub struct RunnerInfo {
/// Implementation name, e.g. "LocalGgufRunner".
pub runner_type: String,
/// Display form of the model file path.
pub model_path: String,
/// Capability tags, e.g. "text_generation".
pub capabilities: Vec<String>,
/// Runner version string, when known.
pub version: Option<String>,
}
/// Runs GGUF models locally by shelling out to a llama.cpp CLI binary.
#[derive(Debug)]
pub struct LocalGgufRunner {
/// Path to the .gguf model file; verified to exist at construction.
model_path: PathBuf,
/// Constraints applied to each spawned llama.cpp process.
sandbox_profile: SandboxProfile,
/// Declared resource needs; exposed via `get_resource_requirements`.
resource_requirements: ModelResourceRequirements,
/// Located llama.cpp executable (see `find_llama_cpp_executable`).
llama_cpp_path: PathBuf,
}
impl LocalGgufRunner {
    /// Creates a runner for the GGUF model at `model_path`, locating a
    /// llama.cpp CLI binary and verifying the setup with a health check.
    ///
    /// # Errors
    /// `ModelFileNotFound` if the model file is missing;
    /// `InitializationFailed` if no llama.cpp executable can be found or the
    /// health check fails.
    pub async fn new(
        model_path: impl Into<PathBuf>,
        sandbox_profile: SandboxProfile,
        resource_requirements: ModelResourceRequirements,
    ) -> Result<Self, SlmRunnerError> {
        let model_path = model_path.into();
        if !model_path.exists() {
            return Err(SlmRunnerError::ModelFileNotFound {
                path: model_path.display().to_string(),
            });
        }
        let llama_cpp_path = Self::find_llama_cpp_executable().await?;
        let runner = Self {
            model_path,
            sandbox_profile,
            resource_requirements,
            llama_cpp_path,
        };
        // Fail fast: run a tiny end-to-end generation before handing the
        // runner to callers.
        runner.health_check().await?;
        Ok(runner)
    }

    /// Searches well-known install locations for `llama-cli`, falling back
    /// to `which` on the PATH.
    async fn find_llama_cpp_executable() -> Result<PathBuf, SlmRunnerError> {
        let candidate_paths = [
            "/usr/local/bin/llama-cli",
            "/usr/bin/llama-cli",
            "/opt/llama.cpp/llama-cli",
            "./bin/llama-cli",
        ];
        for path in candidate_paths {
            let path_buf = PathBuf::from(path);
            if path_buf.exists() {
                return Ok(path_buf);
            }
        }
        match Command::new("which").arg("llama-cli").output().await {
            Ok(output) if output.status.success() => {
                let path_str = String::from_utf8_lossy(&output.stdout);
                let trimmed_path = path_str.trim();
                // Guard against a zero-status `which` that printed nothing:
                // an empty PathBuf would otherwise be treated as a binary.
                if trimmed_path.is_empty() {
                    return Err(SlmRunnerError::InitializationFailed {
                        reason: "llama.cpp executable not found".to_string(),
                    });
                }
                Ok(PathBuf::from(trimmed_path))
            }
            _ => Err(SlmRunnerError::InitializationFailed {
                reason: "llama.cpp executable not found".to_string(),
            }),
        }
    }

    /// Builds the llama.cpp CLI argument list for a single generation.
    /// `custom_parameters` entries are appended as `--{key} {value}`.
    fn build_command_args(&self, prompt: &str, options: &ExecutionOptions) -> Vec<String> {
        let mut args = vec![
            "--model".to_string(),
            self.model_path.display().to_string(),
            "--prompt".to_string(),
            prompt.to_string(),
            "--no-display-prompt".to_string(),
        ];
        if let Some(temp) = options.temperature {
            args.extend(["--temp".to_string(), temp.to_string()]);
        }
        if let Some(max_tokens) = options.max_tokens {
            args.extend(["--n-predict".to_string(), max_tokens.to_string()]);
        }
        // Clamp to at least one thread: a fractional core allowance below 1.0
        // would otherwise floor to "--threads 0".
        let threads = self.sandbox_profile.resources.max_cpu_cores.floor().max(1.0);
        args.extend(["--threads".to_string(), threads.to_string()]);
        for (key, value) in &options.custom_parameters {
            args.extend([format!("--{}", key), value.clone()]);
        }
        args
    }

    /// Applies sandbox restrictions to the child process via environment
    /// variables and the working directory.
    ///
    /// NOTE(review): RLIMIT_AS / NO_NETWORK / ALLOWED_HOSTS are plain env
    /// vars — they limit nothing unless a cooperating wrapper or launcher
    /// enforces them; confirm the deployment provides one.
    fn apply_sandbox_constraints(&self, command: &mut Command) {
        // Reap the child if its future is dropped (e.g. on timeout) so a
        // timed-out llama.cpp process is not left running.
        command.kill_on_drop(true);
        let memory_limit = self.sandbox_profile.resources.max_memory_mb * 1024 * 1024;
        command.env("RLIMIT_AS", memory_limit.to_string());
        // Run inside the first writable path, stripping a trailing "/*" glob.
        if let Some(write_path) = self.sandbox_profile.filesystem.write_paths.first() {
            if let Ok(path) = std::fs::canonicalize(write_path.trim_end_matches("/*")) {
                command.current_dir(path);
            }
        }
        match self.sandbox_profile.network.access_mode {
            crate::config::NetworkAccessMode::None => {
                command.env("NO_NETWORK", "1");
            }
            crate::config::NetworkAccessMode::Restricted => {
                if !self.sandbox_profile.network.allowed_destinations.is_empty() {
                    let hosts: Vec<String> = self.sandbox_profile.network.allowed_destinations
                        .iter()
                        .map(|dest| dest.host.clone())
                        .collect();
                    command.env("ALLOWED_HOSTS", hosts.join(","));
                }
            }
            crate::config::NetworkAccessMode::Full => {}
        }
    }

    /// Rejects prompts over a rough 4000-token budget and re-validates the
    /// sandbox profile before execution.
    fn validate_execution_constraints(&self, prompt: &str) -> Result<(), SlmRunnerError> {
        // Rough heuristic: ~4 bytes per token.
        let estimated_tokens = prompt.len() / 4;
        if estimated_tokens > 4000 {
            return Err(SlmRunnerError::InvalidInput {
                reason: "Prompt too long".to_string(),
            });
        }
        self.sandbox_profile.validate()
            .map_err(|e| SlmRunnerError::SandboxViolation {
                violation: e.to_string(),
            })?;
        Ok(())
    }
}
#[async_trait]
impl SlmRunner for LocalGgufRunner {
    /// Runs one generation by shelling out to llama.cpp, enforcing the
    /// sandbox profile and the configured (or profile-default) timeout.
    ///
    /// # Errors
    /// `InvalidInput`/`SandboxViolation` from validation, `ExecutionTimeout`
    /// if the deadline passes, `ExecutionFailed` if spawning fails or the
    /// process exits non-zero.
    async fn execute(
        &self,
        prompt: &str,
        options: Option<ExecutionOptions>,
    ) -> Result<ExecutionResult, SlmRunnerError> {
        let options = options.unwrap_or_default();
        let start_time = std::time::Instant::now();
        self.validate_execution_constraints(prompt)?;
        let args = self.build_command_args(prompt, &options);
        let mut command = Command::new(&self.llama_cpp_path);
        command.args(&args);
        // When `timeout` fires, the `output()` future is dropped but tokio
        // does not kill the child by default — kill_on_drop reaps it instead
        // of leaving an orphaned llama.cpp process.
        command.kill_on_drop(true);
        self.apply_sandbox_constraints(&mut command);
        let execution_timeout = options.timeout
            .unwrap_or_else(|| Duration::from_secs(self.sandbox_profile.process_limits.max_execution_time_seconds));
        let output = timeout(execution_timeout, command.output())
            .await
            .map_err(|_| SlmRunnerError::ExecutionTimeout {
                seconds: execution_timeout.as_secs(),
            })?
            .map_err(|e| SlmRunnerError::ExecutionFailed {
                reason: format!("Process execution failed: {}", e),
            })?;
        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            return Err(SlmRunnerError::ExecutionFailed {
                reason: format!("llama.cpp execution failed: {}", stderr),
            });
        }
        let response = String::from_utf8_lossy(&output.stdout).trim().to_string();
        let execution_time = start_time.elapsed();
        let metadata = ExecutionMetadata {
            // Token counts are byte-length estimates (~4 bytes/token), not
            // tokenizer-accurate figures.
            input_tokens: Some((prompt.len() / 4) as u32),
            output_tokens: Some((response.len() / 4) as u32),
            execution_time_ms: execution_time.as_millis() as u64,
            memory_usage_mb: None,
            limits_hit: Vec::new(),
        };
        Ok(ExecutionResult { response, metadata })
    }

    fn get_sandbox_profile(&self) -> &SandboxProfile {
        &self.sandbox_profile
    }

    fn get_resource_requirements(&self) -> &ModelResourceRequirements {
        &self.resource_requirements
    }

    /// Verifies the model file and binary still exist, then runs a minimal
    /// one-token generation (10 s timeout) as a smoke test.
    async fn health_check(&self) -> Result<(), SlmRunnerError> {
        if !self.model_path.exists() {
            return Err(SlmRunnerError::ModelFileNotFound {
                path: self.model_path.display().to_string(),
            });
        }
        if !self.llama_cpp_path.exists() {
            return Err(SlmRunnerError::InitializationFailed {
                reason: "llama.cpp executable no longer available".to_string(),
            });
        }
        let test_prompt = "Hello";
        let options = ExecutionOptions {
            timeout: Some(Duration::from_secs(10)),
            temperature: Some(0.1),
            max_tokens: Some(1),
            custom_parameters: HashMap::new(),
        };
        match self.execute(test_prompt, Some(options)).await {
            Ok(_) => Ok(()),
            // Any execution failure is surfaced as an initialization failure
            // since health_check gates construction.
            Err(e) => Err(SlmRunnerError::InitializationFailed {
                reason: format!("Health check failed: {}", e),
            }),
        }
    }

    fn get_info(&self) -> RunnerInfo {
        RunnerInfo {
            runner_type: "LocalGgufRunner".to_string(),
            model_path: self.model_path.display().to_string(),
            capabilities: vec![
                "text_generation".to_string(),
                "conversation".to_string(),
            ],
            version: Some("1.0.0".to_string()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SandboxProfile;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Minimal resource declaration shared by the tests below.
    fn create_test_resource_requirements() -> ModelResourceRequirements {
        ModelResourceRequirements {
            min_memory_mb: 512,
            preferred_cpu_cores: 1.0,
            gpu_requirements: None,
        }
    }

    #[tokio::test]
    async fn test_gguf_runner_creation_missing_file() {
        // Constructing against a path that cannot exist must fail up front.
        let outcome = LocalGgufRunner::new(
            "/nonexistent/model.gguf",
            SandboxProfile::secure_default(),
            create_test_resource_requirements(),
        )
        .await;
        assert!(matches!(outcome, Err(SlmRunnerError::ModelFileNotFound { .. })));
    }

    #[tokio::test]
    async fn test_execution_options_default() {
        let defaults = ExecutionOptions::default();
        assert!(defaults.timeout.is_some());
        assert_eq!(defaults.temperature, Some(0.7));
        assert_eq!(defaults.max_tokens, Some(256));
    }

    #[tokio::test]
    async fn test_command_args_building() {
        let mut dummy_model = NamedTempFile::new().unwrap();
        writeln!(dummy_model, "dummy model content").unwrap();
        let gguf_path = dummy_model.path().to_path_buf();
        let runner = LocalGgufRunner {
            model_path: gguf_path.clone(),
            sandbox_profile: SandboxProfile::secure_default(),
            resource_requirements: create_test_resource_requirements(),
            llama_cpp_path: PathBuf::from("/fake/llama-cli"),
        };
        let args = runner.build_command_args("test prompt", &ExecutionOptions::default());
        // The model path and the prompt must both be forwarded verbatim.
        for expected in [
            "--model".to_string(),
            gguf_path.display().to_string(),
            "--prompt".to_string(),
            "test prompt".to_string(),
        ] {
            assert!(args.contains(&expected));
        }
    }

    #[test]
    fn test_validation_long_prompt() {
        let runner = LocalGgufRunner {
            model_path: PathBuf::from("/fake/model.gguf"),
            sandbox_profile: SandboxProfile::secure_default(),
            resource_requirements: create_test_resource_requirements(),
            llama_cpp_path: PathBuf::from("/fake/llama-cli"),
        };
        // 20_000 bytes estimates to ~5000 tokens, over the 4000 limit.
        let oversized = "a".repeat(20000);
        assert!(matches!(
            runner.validate_execution_constraints(&oversized),
            Err(SlmRunnerError::InvalidInput { .. })
        ));
    }
}