use std::path::PathBuf;
use ollama_rs::{
generation::completion::request::GenerationRequest,
Ollama
};
use tokio_stream::StreamExt;
pub mod errors;
use errors::KernelError;
/// Facade that pairs an Ollama generation client with a sled key/value store.
///
/// Holds the client, the name of the model used for every generation request,
/// and the sled database opened in `CorellKernel::new`.
pub struct CorellKernel {
    /// Client used for both blocking and streaming generation requests.
    ollama: Ollama,
    /// Model name passed to every `GenerationRequest`.
    model: String,
    /// sled database backing `save_data` / `get_data`.
    storage: sled::Db,
}
impl CorellKernel {
    /// Creates a kernel that sends prompts to `model` and persists key/value
    /// data in a sled database at `storage_path`.
    ///
    /// The client is built with `Ollama::default()`. Presumably that targets
    /// the library's default local endpoint; confirm this against the
    /// `ollama_rs` version in use.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::StorageInitError`] if `sled::open` fails for
    /// `storage_path`.
    pub fn new(model: &str, storage_path: PathBuf) -> Result<Self, KernelError> {
        let ollama = Ollama::default();
        let storage = sled::open(storage_path)
            .map_err(|e| KernelError::StorageInitError(e.to_string()))?;
        Ok(Self {
            ollama,
            model: model.to_string(),
            storage,
        })
    }

    /// Stores `value` under `key`, then flushes the database.
    ///
    /// Any previous value returned by the insert is discarded. Flushing after
    /// every insert trades write throughput for having the data persisted
    /// before this call returns.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::StorageOperationError`] if either the insert or
    /// the flush fails.
    pub fn save_data(&self, key: &str, value: &str) -> Result<(), KernelError> {
        self.storage
            .insert(key, value)
            .map_err(|e| KernelError::StorageOperationError(e.to_string()))?;
        self.storage
            .flush()
            .map_err(|e| KernelError::StorageOperationError(e.to_string()))?;
        Ok(())
    }

    /// Looks up `key` and returns its stored bytes decoded as UTF-8.
    ///
    /// Returns `Ok(None)` when the key is absent.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::StorageOperationError`] if the read fails or the
    /// stored bytes are not valid UTF-8.
    pub fn get_data(&self, key: &str) -> Result<Option<String>, KernelError> {
        self.storage
            .get(key)
            .map_err(|e| KernelError::StorageOperationError(e.to_string()))?
            // `transpose` turns `Option<Result<String, _>>` into
            // `Result<Option<String>, _>`, so an absent key stays `Ok(None)`
            // and only a decoding failure becomes an error.
            .map(|ivec| String::from_utf8(ivec.to_vec()))
            .transpose()
            .map_err(|e| KernelError::StorageOperationError(e.to_string()))
    }

    /// Builds the single text prompt sent to the model. The system prompt and
    /// the input are placed under labelled `System:` / `Input:` sections.
    ///
    /// Shared by [`Self::execute`] and [`Self::execute_stream`] so both entry
    /// points always send identically formatted prompts.
    ///
    /// NOTE(review): `input_data` is not escaped, so input that itself contains
    /// `System:` / `Input:` labels looks the same to the model as the real
    /// sections. If `GenerationRequest` offers a dedicated system-prompt
    /// setter, consider using it instead.
    fn build_prompt(system_prompt: &str, input_data: &str) -> String {
        format!("System:\n{}\n\nInput:\n{}", system_prompt, input_data)
    }

    /// Sends one non-streaming generation request and returns the complete
    /// response text.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InferenceError`] if the Ollama request fails.
    pub async fn execute(
        &self,
        system_prompt: &str,
        input_data: &str,
    ) -> Result<String, KernelError> {
        let request = GenerationRequest::new(
            self.model.clone(),
            Self::build_prompt(system_prompt, input_data),
        );
        let response = self
            .ollama
            .generate(request)
            .await
            .map_err(|e| KernelError::InferenceError(e.to_string()))?;
        Ok(response.response)
    }

    /// Starts a streaming generation request.
    ///
    /// Each item of the returned stream is one batch of chunks yielded by the
    /// Ollama client, with the chunks' `response` texts concatenated into a
    /// single `String`. A batch that failed is yielded as
    /// [`KernelError::InferenceError`].
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InferenceError`] if the stream cannot be opened.
    pub async fn execute_stream(
        &self,
        system_prompt: &str,
        input_data: &str,
    ) -> Result<impl tokio_stream::Stream<Item = Result<String, KernelError>>, KernelError> {
        let request = GenerationRequest::new(
            self.model.clone(),
            Self::build_prompt(system_prompt, input_data),
        );
        let stream = self
            .ollama
            .generate_stream(request)
            .await
            .map_err(|e| KernelError::InferenceError(e.to_string()))?;
        Ok(stream.map(|res| {
            res.map(|chunks| {
                // The batch is owned and dropped right after this call, so
                // move each chunk's text out instead of cloning it.
                chunks
                    .into_iter()
                    .map(|chunk| chunk.response)
                    .collect::<String>()
            })
            .map_err(|e| KernelError::InferenceError(e.to_string()))
        }))
    }
}