use serde::{Deserialize, Serialize};
use thiserror::Error;
/// A structured error returned by LLM provider operations.
///
/// Carries a machine-readable [`LlmErrorKind`], a human-readable message,
/// and a `retryable` hint so callers can decide whether to retry.
#[derive(Debug, Clone, Serialize, Deserialize, Error)]
#[error("{kind:?}: {message}")]
pub struct LlmError {
    /// Category of the failure (auth, rate limit, network, ...).
    pub kind: LlmErrorKind,
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Whether retrying the same request may succeed.
    pub retryable: bool,
}
/// Machine-readable category of an [`LlmError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LlmErrorKind {
    /// Invalid or missing credentials.
    Authentication,
    /// The provider rejected the request due to rate limiting.
    RateLimit,
    /// The request itself was malformed or unacceptable.
    InvalidRequest,
    /// The requested model does not exist or is unavailable.
    ModelNotFound,
    /// Transport-level failure reaching the provider.
    Network,
    /// The provider reported an internal error.
    ProviderError,
    /// The provider's response could not be parsed.
    ParseError,
    /// The request timed out.
    Timeout,
}
impl LlmError {
    /// Creates an error with an explicit kind, message, and retry hint.
    pub fn new(kind: LlmErrorKind, message: impl Into<String>, retryable: bool) -> Self {
        Self {
            kind,
            message: message.into(),
            retryable,
        }
    }

    /// Authentication failure; retrying with the same credentials cannot succeed.
    pub fn auth(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::Authentication, message, false)
    }

    /// Rate-limit rejection; safe to retry after backing off.
    pub fn rate_limit(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::RateLimit, message, true)
    }

    /// Transport-level failure; transient, so retryable.
    pub fn network(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::Network, message, true)
    }

    /// Response-parsing failure; the same response would fail again.
    pub fn parse(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::ParseError, message, false)
    }

    /// Provider-side internal error; not retried by default.
    pub fn provider(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::ProviderError, message, false)
    }

    /// Malformed request; retrying the identical request cannot succeed.
    /// (Added for parity: every other kind has a constructor.)
    pub fn invalid_request(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::InvalidRequest, message, false)
    }

    /// Unknown or unavailable model; not retryable with the same model id.
    pub fn model_not_found(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::ModelNotFound, message, false)
    }

    /// Request timeout; treated as transient, so retryable.
    pub fn timeout(message: impl Into<String>) -> Self {
        Self::new(LlmErrorKind::Timeout, message, true)
    }
}
/// A provider-agnostic completion request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// User prompt text sent to the model.
    pub prompt: String,
    /// Optional system prompt / instructions.
    pub system: Option<String>,
    /// Maximum number of tokens the model may generate.
    pub max_tokens: u32,
    /// Sampling temperature.
    pub temperature: f64,
    /// Sequences that stop generation when emitted.
    pub stop_sequences: Vec<String>,
}
impl LlmRequest {
    /// Builds a request for `prompt` with defaults: no system prompt,
    /// 1024 max tokens, temperature 0.7, and no stop sequences.
    #[must_use]
    pub fn new(prompt: impl Into<String>) -> Self {
        Self {
            prompt: prompt.into(),
            system: None,
            max_tokens: 1024,
            temperature: 0.7,
            stop_sequences: Vec::new(),
        }
    }

    /// Sets the system prompt.
    #[must_use]
    pub fn with_system(self, system: impl Into<String>) -> Self {
        Self {
            system: Some(system.into()),
            ..self
        }
    }

    /// Sets the generation token limit.
    #[must_use]
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self { max_tokens, ..self }
    }

    /// Sets the sampling temperature.
    #[must_use]
    pub fn with_temperature(self, temperature: f64) -> Self {
        Self { temperature, ..self }
    }

    /// Appends one stop sequence to the request.
    #[must_use]
    pub fn with_stop_sequence(mut self, stop: impl Into<String>) -> Self {
        self.stop_sequences.push(stop.into());
        self
    }
}
/// A completed response from an LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Generated text.
    pub content: String,
    /// Identifier of the model that produced the response.
    pub model: String,
    /// Token accounting reported for the call.
    pub usage: TokenUsage,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
}
/// Token counts for a single request/response pair.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt (input side).
    pub prompt_tokens: u32,
    /// Tokens generated in the completion (output side).
    pub completion_tokens: u32,
    /// Total tokens (typically prompt + completion, as set by the provider adapter).
    pub total_tokens: u32,
}
/// Reason the model stopped generating; serialized in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Natural end of generation.
    Stop,
    /// Generation hit the `max_tokens` limit.
    MaxTokens,
    /// A configured stop sequence was emitted.
    StopSequence,
    /// Output was cut off by the provider's content filter.
    ContentFilter,
}
/// Abstraction over a concrete LLM backend.
///
/// Implementations must be thread-safe (`Send + Sync`) so they can be
/// shared across agents via `Arc`.
pub trait LlmProvider: Send + Sync {
    /// Stable identifier of the provider implementation.
    fn name(&self) -> &'static str;
    /// Identifier of the model this provider targets.
    fn model(&self) -> &str;
    /// Performs a blocking completion call for `request`.
    fn complete(&self, request: &LlmRequest) -> Result<LlmResponse, LlmError>;
    /// Builds a provenance tag of the form `"<model>:<request_id>"`.
    /// NOTE(review): uses `model()`, not `name()` — confirm that is intended.
    fn provenance(&self, request_id: &str) -> String {
        format!("{}:{}", self.model(), request_id)
    }
}
use crate::agent::Agent;
use crate::context::{ContextKey, ProposedFact};
use crate::effect::AgentEffect;
use crate::validation::encode_proposal;
use std::sync::Arc;
/// Configuration for an [`LlmAgent`].
#[derive(Clone)]
pub struct LlmAgentConfig {
    /// System prompt sent with requests (empty string = none).
    pub system_prompt: String,
    /// Prompt template; "{context}" presumably marks where rendered context
    /// is substituted — TODO confirm against `crate::prompt`.
    pub prompt_template: String,
    /// Format used when rendering context into the prompt.
    pub prompt_format: crate::prompt::PromptFormat,
    /// Context key this agent writes its proposed facts to.
    pub target_key: ContextKey,
    /// Context keys this agent reads as input.
    pub dependencies: Vec<ContextKey>,
    /// Default confidence for facts proposed by this agent.
    pub default_confidence: f64,
    /// Token generation limit forwarded to the provider.
    pub max_tokens: u32,
    /// Sampling temperature forwarded to the provider.
    pub temperature: f64,
    /// Optional model-selection requirements for this agent.
    pub requirements: Option<crate::model_selection::AgentRequirements>,
}
impl Default for LlmAgentConfig {
    /// Defaults: Seeds -> Hypotheses, EDN-formatted "{context}" template,
    /// no system prompt, confidence 0.7, 1024 tokens, temperature 0.7.
    fn default() -> Self {
        Self {
            target_key: ContextKey::Hypotheses,
            dependencies: vec![ContextKey::Seeds],
            system_prompt: String::new(),
            prompt_template: String::from("{context}"),
            prompt_format: crate::prompt::PromptFormat::Edn,
            default_confidence: 0.7,
            max_tokens: 1024,
            temperature: 0.7,
            requirements: None,
        }
    }
}
/// Converts a raw LLM response into zero or more proposed facts.
pub trait ResponseParser: Send + Sync {
    /// Parses `response`, tagging each proposal with `target_key`.
    fn parse(&self, response: &LlmResponse, target_key: ContextKey) -> Vec<ProposedFact>;
}
/// Parser that turns the entire response body into a single fact.
pub struct SimpleParser {
    /// Prefix for generated fact ids (`"<prefix>-<suffix>"`).
    pub id_prefix: String,
    /// Confidence assigned to every produced fact.
    pub confidence: f64,
}
impl Default for SimpleParser {
    /// Defaults: id prefix `"llm"`, confidence 0.7.
    fn default() -> Self {
        Self {
            confidence: 0.7,
            id_prefix: String::from("llm"),
        }
    }
}
impl ResponseParser for SimpleParser {
    /// Wraps the whole (trimmed) response text in one proposed fact.
    /// An empty response yields no proposals.
    fn parse(&self, response: &LlmResponse, target_key: ContextKey) -> Vec<ProposedFact> {
        let text = response.content.trim();
        if text.is_empty() {
            return Vec::new();
        }
        let fact = ProposedFact {
            key: target_key,
            id: format!("{}-{}", self.id_prefix, uuid_v4_simple()),
            content: text.to_owned(),
            confidence: self.confidence,
            provenance: response.model.clone(),
        };
        vec![fact]
    }
}
/// An agent that calls an LLM provider and proposes facts from the reply.
pub struct LlmAgent {
    name: String,
    /// Backend used for completions; shared and thread-safe.
    provider: Arc<dyn LlmProvider>,
    config: LlmAgentConfig,
    /// Converts provider responses into proposed facts.
    parser: Arc<dyn ResponseParser>,
    /// `config.dependencies` plus the target key (appended in `new`).
    full_dependencies: Vec<ContextKey>,
}
impl LlmAgent {
    /// Creates an agent backed by `provider`.
    ///
    /// The target key is appended to the declared dependencies (if absent)
    /// so the agent can observe its own prior output, and a [`SimpleParser`]
    /// is installed that prefixes fact ids with the agent name.
    pub fn new(
        name: impl Into<String>,
        provider: Arc<dyn LlmProvider>,
        config: LlmAgentConfig,
    ) -> Self {
        let name_str = name.into();
        // The agent reads its own target key (for the dedup check in
        // `accepts`), so it must be listed among the dependencies.
        let mut full_dependencies = config.dependencies.clone();
        if !full_dependencies.contains(&config.target_key) {
            full_dependencies.push(config.target_key);
        }
        // Fix: honor the configured confidence instead of a hard-coded 0.7,
        // so `LlmAgentConfig::default_confidence` actually takes effect.
        let parser = Arc::new(SimpleParser {
            id_prefix: name_str.clone(),
            confidence: config.default_confidence,
        });
        Self {
            name: name_str,
            provider,
            config,
            parser,
            full_dependencies,
        }
    }
}
impl Agent for LlmAgent {
    /// The agent's unique name, also used as the fact-id prefix.
    fn name(&self) -> &str {
        &self.name
    }
    /// Declared dependencies, including the target key (added in `new`).
    fn dependencies(&self) -> &[ContextKey] {
        &self.full_dependencies
    }
    /// Runs only when at least one declared input key is present and this
    /// agent has not already contributed a fact under the target key.
    fn accepts(&self, ctx: &dyn crate::ContextView) -> bool {
        let has_input = self.config.dependencies.iter().any(|k| ctx.has(*k));
        if !has_input {
            return false;
        }
        // Fact ids are "<agent-name>-<suffix>" (see SimpleParser), so a
        // matching prefix on an existing fact means this agent already ran.
        let my_prefix = format!("{}-", self.name);
        !ctx.get(self.config.target_key)
            .iter()
            .any(|f| f.id.starts_with(&my_prefix))
    }
    /// Calls the provider once and converts the response into facts.
    ///
    /// NOTE(review): the literal "prompt" looks like a stub — the context
    /// (`_ctx` is unused), `prompt_template`, `prompt_format`, and
    /// `system_prompt` from the config are never applied here. Confirm
    /// whether template rendering is missing or happens elsewhere.
    fn execute(&self, _ctx: &dyn crate::ContextView) -> AgentEffect {
        let request = LlmRequest::new("prompt") .with_max_tokens(self.config.max_tokens)
            .with_temperature(self.config.temperature);
        match self.provider.complete(&request) {
            Ok(response) => {
                let proposals = self.parser.parse(&response, self.config.target_key);
                let facts: Vec<_> = proposals.iter().map(encode_proposal).collect();
                AgentEffect::with_facts(facts)
            }
            // Provider errors are swallowed: a failed call produces no facts.
            Err(_) => AgentEffect::empty(),
        }
    }
}
/// Generates a short unique id suffix for fact ids.
///
/// Not a real RFC 4122 UUID despite the name: it combines the current Unix
/// time in nanoseconds with a process-wide atomic counter, which is unique
/// enough for fact ids without pulling in the `uuid` crate. The previous
/// stub returned the constant "test", which made every fact id collide.
fn uuid_v4_simple() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};
    // Monotonic per-process counter guarantees uniqueness even when two
    // ids are generated within the same clock tick.
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos() as u64)
        .unwrap_or(0);
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("{:x}-{:x}", nanos, seq)
}