use std::sync::Arc;
use super::cache::{generate_cache_key, ResponseCache};
use super::config::BehaviorModelConfig;
use super::context::StatefulAiContext;
use super::llm_client::LlmClient;
use super::rules::EvaluationContext;
use super::types::{BehaviorRules, LlmGenerationRequest};
use crate::Result;
/// Generates realistic API responses with an LLM, guarded by configurable
/// consistency rules and backed by an optional response cache.
pub struct BehaviorModel {
    /// Configuration this model was built from (also supplies the LLM
    /// temperature and token limit for each generation request).
    config: BehaviorModelConfig,
    /// Copy of `config.rules`: consistency rules, schemas and the LLM system prompt.
    rules: BehaviorRules,
    /// Shared client used to issue LLM generation requests.
    llm_client: Arc<LlmClient>,
    /// Response cache keyed by method, path and request body; `None` disables caching.
    cache: Option<Arc<ResponseCache>>,
}
impl BehaviorModel {
pub fn new(config: BehaviorModelConfig) -> Self {
let rules = config.rules.clone();
let llm_client = Arc::new(LlmClient::new(config.clone()));
let cache = Some(Arc::new(ResponseCache::new(300)));
Self {
config,
rules,
llm_client,
cache,
}
}
pub async fn generate_response(
&self,
method: &str,
path: &str,
request_body: Option<serde_json::Value>,
context: &StatefulAiContext,
) -> Result<serde_json::Value> {
if let Some(ref cache) = self.cache {
let cache_key = generate_cache_key(method, path, request_body.as_ref());
if let Some(cached_response) = cache.get(&cache_key).await {
tracing::debug!("Cache hit for {} {}", method, path);
return Ok(cached_response);
}
}
self.check_consistency_rules(method, path, context).await?;
let prompt = self.build_prompt(method, path, request_body.as_ref(), context).await;
let response = self.generate_with_llm(&prompt).await?;
if let Some(ref cache) = self.cache {
let cache_key = generate_cache_key(method, path, request_body.as_ref());
cache.put(cache_key, response.clone()).await;
}
Ok(response)
}
async fn check_consistency_rules(
&self,
method: &str,
path: &str,
context: &StatefulAiContext,
) -> Result<()> {
let state = context.get_state().await;
let _eval_context =
EvaluationContext::new(method, path).with_session_state(state.state.clone());
let mut rules = self.rules.consistency_rules.clone();
rules.sort_by(|a, b| b.priority.cmp(&a.priority));
for rule in &rules {
if rule.matches(method, path) {
match &rule.action {
super::rules::RuleAction::RequireAuth { message } => {
if !state.state.contains_key("auth_token")
&& !state.state.contains_key("user_id")
{
return Err(crate::Error::internal(message.clone()));
}
}
super::rules::RuleAction::Error { status, message } => {
return Err(crate::Error::internal(format!(
"Rule '{}' failed: {} (status {})",
rule.name, message, status
)));
}
_ => {
}
}
}
}
Ok(())
}
async fn build_prompt(
&self,
method: &str,
path: &str,
request_body: Option<&serde_json::Value>,
context: &StatefulAiContext,
) -> String {
let mut prompt = format!(
"Generate a realistic response for this API request:\n\n\
Method: {}\n\
Path: {}\n",
method, path
);
if let Some(body) = request_body {
prompt.push_str(&format!("Request Body: {}\n", body));
}
let context_summary = context.build_context_summary().await;
prompt.push('\n');
prompt.push_str(&context_summary);
if !self.rules.schemas.is_empty() {
prompt.push_str("\n# Available Schemas\n");
for (name, schema) in &self.rules.schemas {
prompt.push_str(&format!("- {}: {}\n", name, schema));
}
}
prompt.push_str("\nGenerate a realistic JSON response that:\n");
prompt.push_str("1. Matches the request method and path\n");
prompt.push_str("2. Is consistent with the session context\n");
prompt.push_str("3. Conforms to the relevant schema if applicable\n");
prompt.push_str("4. Maintains logical consistency\n");
prompt
}
async fn generate_with_llm(&self, prompt: &str) -> Result<serde_json::Value> {
tracing::debug!("Generating LLM response with prompt ({} chars)", prompt.len());
let request = LlmGenerationRequest::new(self.rules.system_prompt.clone(), prompt)
.with_temperature(self.config.temperature)
.with_max_tokens(self.config.max_tokens);
self.llm_client.generate(&request).await
}
pub fn rules(&self) -> &BehaviorRules {
&self.rules
}
pub fn config(&self) -> &BehaviorModelConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::super::config::IntelligentBehaviorConfig;
    use super::*;

    /// `new` must carry the configured rules through and enable the response cache.
    #[tokio::test]
    async fn test_behavior_model_creation() {
        let config = BehaviorModelConfig::default();
        let model = BehaviorModel::new(config.clone());

        // Compare against the input config so the test can actually fail.
        assert_eq!(model.rules().schemas.len(), config.rules.schemas.len());
        assert_eq!(
            model.rules().consistency_rules.len(),
            config.rules.consistency_rules.len()
        );
        assert!(model.cache.is_some(), "response cache should be enabled by default");
    }

    /// End-to-end generation against a live LLM; skipped unless `OPENAI_API_KEY` is set.
    #[tokio::test]
    async fn test_generate_response() {
        if std::env::var("OPENAI_API_KEY").is_err() {
            eprintln!("Skipping test_generate_response: OPENAI_API_KEY not set");
            return;
        }
        let config = BehaviorModelConfig::default();
        let model = BehaviorModel::new(config);
        let ai_config = IntelligentBehaviorConfig::default();
        let context = StatefulAiContext::new("test_session", ai_config);
        let response = model.generate_response("GET", "/api/users", None, &context).await.unwrap();
        assert!(response.is_object());
    }
}