use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
/// Response-generation mode selected by [`AiResponseConfig::mode`].
///
/// Serialized with lowercase names (`"static"`, `"intelligent"`, `"hybrid"`).
/// `Static` is the default variant, and [`AiResponseConfig::is_active`] treats it
/// as "AI generation off".
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum AiResponseMode {
    /// Default mode; never considered active by `AiResponseConfig::is_active`.
    #[default]
    Static,
    /// NOTE(review): presumably responses generated entirely by the AI — confirm.
    Intelligent,
    /// NOTE(review): presumably a mix of static and AI-generated content — confirm.
    Hybrid,
}
/// Settings controlling AI-generated responses.
///
/// When deserialized, omitted fields fall back to the same values that
/// [`AiResponseConfig::default`] produces: disabled, `Static` mode, no
/// prompt/context/schema, temperature `0.7`, `1024` max tokens, caching on.
/// Field order is significant: it is the order fields are serialized in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiResponseConfig {
    /// Master switch; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,

    /// Generation mode; [`AiResponseMode::Static`] when omitted.
    #[serde(default)]
    pub mode: AiResponseMode,

    /// Prompt text; a prompt must be present for [`AiResponseConfig::is_active`]
    /// to return `true`. NOTE(review): presumably a template with `{{...}}`
    /// placeholders — confirm against the caller.
    pub prompt: Option<String>,

    /// Optional extra context text. NOTE(review): how it is combined with the
    /// prompt is not visible here.
    pub context: Option<String>,

    /// Generation temperature; `0.7` when omitted.
    #[serde(default = "default_temperature")]
    pub temperature: f32,

    /// Maximum token count; `1024` when omitted.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,

    /// Optional JSON value. NOTE(review): presumably a JSON Schema describing the
    /// expected output shape — confirm.
    pub schema: Option<Value>,

    /// Whether caching is enabled; `true` when omitted.
    #[serde(default = "default_true")]
    pub cache_enabled: bool,
}
/// Temperature used when a serialized config omits `temperature`.
const DEFAULT_TEMPERATURE: f32 = 0.7;

/// Token limit used when a serialized config omits `max_tokens`.
const DEFAULT_MAX_TOKENS: usize = 1024;

/// Serde default for `AiResponseConfig::temperature`.
fn default_temperature() -> f32 {
    DEFAULT_TEMPERATURE
}

/// Serde default for `AiResponseConfig::max_tokens`.
fn default_max_tokens() -> usize {
    DEFAULT_MAX_TOKENS
}

/// Serde default for boolean fields that should be on unless stated otherwise
/// (used by `AiResponseConfig::cache_enabled`).
fn default_true() -> bool {
    true
}
impl Default for AiResponseConfig {
fn default() -> Self {
Self {
enabled: false,
mode: AiResponseMode::Static,
prompt: None,
context: None,
temperature: default_temperature(),
max_tokens: default_max_tokens(),
schema: None,
cache_enabled: true,
}
}
}
impl AiResponseConfig {
    /// Creates a configuration with the given switch, mode and prompt; every other
    /// field takes its [`Default`] value.
    pub fn new(enabled: bool, mode: AiResponseMode, prompt: String) -> Self {
        let base = Self::default();
        Self {
            enabled,
            mode,
            prompt: Some(prompt),
            ..base
        }
    }

    /// Returns `true` only when the config is enabled, the mode is not
    /// [`AiResponseMode::Static`], and a prompt is present.
    pub fn is_active(&self) -> bool {
        let has_prompt = self.prompt.is_some();
        let non_static_mode = !matches!(self.mode, AiResponseMode::Static);
        self.enabled && non_static_mode && has_prompt
    }
}
/// Snapshot of an incoming request, used as the value source for prompt
/// template placeholders.
///
/// All maps start empty (via `Default`); use the `with_*` builder methods on
/// [`RequestContext`] to populate them.
#[derive(Debug, Clone, Default)]
pub struct RequestContext {
    /// Request method string (e.g. `"GET"`).
    pub method: String,
    /// Request path string.
    pub path: String,
    /// Path parameters keyed by parameter name.
    pub path_params: HashMap<String, Value>,
    /// Query parameters keyed by parameter name.
    pub query_params: HashMap<String, Value>,
    /// Request headers. NOTE(review): lookups are exact-key; confirm whether
    /// callers normalize header-name case before inserting.
    pub headers: HashMap<String, Value>,
    /// Parsed request body, if any.
    pub body: Option<Value>,
    /// Multipart form fields keyed by field name.
    pub multipart_fields: HashMap<String, Value>,
    /// Multipart files. NOTE(review): presumably field name -> file name or
    /// stored path — confirm.
    pub multipart_files: HashMap<String, String>,
}
impl RequestContext {
    /// Starts a context for `method` and `path`; all other fields are empty.
    pub fn new(method: String, path: String) -> Self {
        let empty = Self::default();
        Self {
            method,
            path,
            ..empty
        }
    }

    /// Replaces the path parameters, returning the updated context.
    pub fn with_path_params(self, params: HashMap<String, Value>) -> Self {
        Self {
            path_params: params,
            ..self
        }
    }

    /// Replaces the query parameters, returning the updated context.
    pub fn with_query_params(self, params: HashMap<String, Value>) -> Self {
        Self {
            query_params: params,
            ..self
        }
    }

    /// Replaces the headers, returning the updated context.
    pub fn with_headers(self, headers: HashMap<String, Value>) -> Self {
        Self { headers, ..self }
    }

    /// Sets the request body, returning the updated context.
    pub fn with_body(self, body: Value) -> Self {
        Self {
            body: Some(body),
            ..self
        }
    }

    /// Replaces the multipart form fields, returning the updated context.
    pub fn with_multipart_fields(self, fields: HashMap<String, Value>) -> Self {
        Self {
            multipart_fields: fields,
            ..self
        }
    }

    /// Replaces the multipart files, returning the updated context.
    pub fn with_multipart_files(self, files: HashMap<String, String>) -> Self {
        Self {
            multipart_files: files,
            ..self
        }
    }
}
#[deprecated(note = "Use mockforge_template_expansion::expand_prompt_template instead")]
pub fn expand_prompt_template(template: &str, context: &RequestContext) -> String {
let mut result = template.to_string();
result = result
.replace("{{request.query.", "{{query.")
.replace("{{request.path.", "{{path.")
.replace("{{request.headers.", "{{headers.")
.replace("{{request.body.", "{{body.")
.replace("{{request.method}}", "{{method}}")
.replace("{{request.path}}", "{{path}}");
result = result.replace("{{method}}", &context.method);
result = result.replace("{{path}}", &context.path);
for (key, val) in &context.path_params {
let placeholder = format!("{{{{path.{key}}}}}");
let replacement = value_to_string(val);
result = result.replace(&placeholder, &replacement);
}
for (key, val) in &context.query_params {
let placeholder = format!("{{{{query.{key}}}}}");
let replacement = value_to_string(val);
result = result.replace(&placeholder, &replacement);
}
for (key, val) in &context.headers {
let placeholder = format!("{{{{headers.{key}}}}}");
let replacement = value_to_string(val);
result = result.replace(&placeholder, &replacement);
}
if let Some(body) = &context.body {
if let Some(obj) = body.as_object() {
for (key, val) in obj {
let placeholder = format!("{{{{body.{key}}}}}");
let replacement = value_to_string(val);
result = result.replace(&placeholder, &replacement);
}
}
}
for (key, val) in &context.multipart_fields {
let placeholder = format!("{{{{multipart.{key}}}}}");
let replacement = value_to_string(val);
result = result.replace(&placeholder, &replacement);
}
result
}
/// Renders a JSON value for insertion into a prompt.
///
/// Strings are inserted without quotes. Numbers and booleans use their display
/// form, and `null` becomes the text `null`. Arrays and objects are emitted as
/// compact JSON, or an empty string if serialization fails.
fn value_to_string(val: &Value) -> String {
    match val {
        Value::String(text) => text.to_owned(),
        Value::Number(number) => number.to_string(),
        Value::Bool(flag) => flag.to_string(),
        Value::Null => String::from("null"),
        Value::Array(_) | Value::Object(_) => serde_json::to_string(val).unwrap_or_default(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use mockforge_template_expansion::{
        expand_prompt_template, RequestContext as TemplateRequestContext,
    };
    use serde_json::json;

    /// Copies this module's `RequestContext` into the context type of
    /// `mockforge_template_expansion`. The template tests below exercise that
    /// crate's `expand_prompt_template`, which shadows the local deprecated one.
    fn to_template_context(context: &RequestContext) -> TemplateRequestContext {
        TemplateRequestContext {
            method: context.method.clone(),
            path: context.path.clone(),
            path_params: context.path_params.clone(),
            query_params: context.query_params.clone(),
            headers: context.headers.clone(),
            body: context.body.clone(),
            multipart_fields: context.multipart_fields.clone(),
            multipart_files: context.multipart_files.clone(),
        }
    }

    #[test]
    fn test_ai_response_config_default() {
        let config = AiResponseConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.mode, AiResponseMode::Static);
        assert!(!config.is_active());
    }

    #[test]
    fn test_ai_response_config_is_active() {
        let config =
            AiResponseConfig::new(true, AiResponseMode::Intelligent, "Test prompt".to_string());
        assert!(config.is_active());
        let config_disabled = AiResponseConfig {
            enabled: false,
            mode: AiResponseMode::Intelligent,
            prompt: Some("Test".to_string()),
            ..Default::default()
        };
        assert!(!config_disabled.is_active());
        // Static mode is never active, even when enabled with a prompt.
        let config_static =
            AiResponseConfig::new(true, AiResponseMode::Static, "Test".to_string());
        assert!(!config_static.is_active());
        // A prompt is required.
        let config_no_prompt = AiResponseConfig {
            enabled: true,
            mode: AiResponseMode::Hybrid,
            prompt: None,
            ..Default::default()
        };
        assert!(!config_no_prompt.is_active());
    }

    #[test]
    fn test_ai_response_config_deserialize_defaults() {
        // Every field is optional on the wire; omitted fields take serde defaults.
        let config: AiResponseConfig =
            serde_json::from_value(json!({})).expect("empty object should deserialize");
        assert!(!config.enabled);
        assert_eq!(config.mode, AiResponseMode::Static);
        assert!(config.prompt.is_none());
        assert!(config.context.is_none());
        assert!((config.temperature - 0.7).abs() < f32::EPSILON);
        assert_eq!(config.max_tokens, 1024);
        assert!(config.schema.is_none());
        assert!(config.cache_enabled);
    }

    #[test]
    fn test_ai_response_mode_serde_lowercase() {
        assert_eq!(
            serde_json::to_value(AiResponseMode::Hybrid).expect("mode should serialize"),
            json!("hybrid")
        );
        let mode: AiResponseMode =
            serde_json::from_value(json!("intelligent")).expect("mode should deserialize");
        assert_eq!(mode, AiResponseMode::Intelligent);
    }

    #[test]
    fn test_request_context_builder() {
        let mut path_params = HashMap::new();
        path_params.insert("id".to_string(), json!("123"));
        let context = RequestContext::new("POST".to_string(), "/users/123".to_string())
            .with_path_params(path_params)
            .with_body(json!({"name": "John"}));
        assert_eq!(context.method, "POST");
        assert_eq!(context.path, "/users/123");
        assert_eq!(context.path_params.get("id"), Some(&json!("123")));
        assert_eq!(context.body, Some(json!({"name": "John"})));
    }

    #[test]
    #[allow(deprecated)] // deliberately covers the local, deprecated implementation
    fn test_local_expand_prompt_template_request_aliases() {
        let mut query_params = HashMap::new();
        query_params.insert("q".to_string(), json!("rust"));
        let context = RequestContext::new("GET".to_string(), "/search".to_string())
            .with_query_params(query_params)
            .with_body(json!({"n": 3}));
        let expanded = super::expand_prompt_template(
            "{{request.method}} {{request.path}} q={{request.query.q}} n={{body.n}} {{unknown}}",
            &context,
        );
        assert_eq!(expanded, "GET /search q=rust n=3 {{unknown}}");
    }

    #[test]
    fn test_expand_prompt_template_basic() {
        let context = RequestContext::new("GET".to_string(), "/users".to_string());
        let template = "Method: {{method}}, Path: {{path}}";
        let template_context = to_template_context(&context);
        let expanded = expand_prompt_template(template, &template_context);
        assert_eq!(expanded, "Method: GET, Path: /users");
    }

    #[test]
    fn test_expand_prompt_template_body() {
        let body = json!({
            "message": "Hello",
            "user": "Alice"
        });
        let context = RequestContext::new("POST".to_string(), "/chat".to_string()).with_body(body);
        let template = "User {{body.user}} says: {{body.message}}";
        let template_context = to_template_context(&context);
        let expanded = expand_prompt_template(template, &template_context);
        assert_eq!(expanded, "User Alice says: Hello");
    }

    #[test]
    fn test_expand_prompt_template_path_params() {
        let mut path_params = HashMap::new();
        path_params.insert("id".to_string(), json!("456"));
        path_params.insert("name".to_string(), json!("test"));
        let context = RequestContext::new("GET".to_string(), "/users/456".to_string())
            .with_path_params(path_params);
        let template = "Get user {{path.id}} with name {{path.name}}";
        let template_context = to_template_context(&context);
        let expanded = expand_prompt_template(template, &template_context);
        assert_eq!(expanded, "Get user 456 with name test");
    }

    #[test]
    fn test_expand_prompt_template_query_params() {
        let mut query_params = HashMap::new();
        query_params.insert("search".to_string(), json!("term"));
        query_params.insert("limit".to_string(), json!(10));
        let context = RequestContext::new("GET".to_string(), "/search".to_string())
            .with_query_params(query_params);
        let template = "Search for {{query.search}} with limit {{query.limit}}";
        let template_context = to_template_context(&context);
        let expanded = expand_prompt_template(template, &template_context);
        assert_eq!(expanded, "Search for term with limit 10");
    }

    #[test]
    fn test_expand_prompt_template_headers() {
        let mut headers = HashMap::new();
        headers.insert("user-agent".to_string(), json!("TestClient/1.0"));
        let context =
            RequestContext::new("GET".to_string(), "/api".to_string()).with_headers(headers);
        let template = "Request from {{headers.user-agent}}";
        let template_context = to_template_context(&context);
        let expanded = expand_prompt_template(template, &template_context);
        assert_eq!(expanded, "Request from TestClient/1.0");
    }

    #[test]
    fn test_expand_prompt_template_complex() {
        let mut path_params = HashMap::new();
        path_params.insert("id".to_string(), json!("789"));
        let mut query_params = HashMap::new();
        query_params.insert("format".to_string(), json!("json"));
        let body = json!({"action": "update", "value": 42});
        let context = RequestContext::new("PUT".to_string(), "/api/items/789".to_string())
            .with_path_params(path_params)
            .with_query_params(query_params)
            .with_body(body);
        let template = "{{method}} item {{path.id}} with action {{body.action}} and value {{body.value}} in format {{query.format}}";
        let template_context = to_template_context(&context);
        let expanded = expand_prompt_template(template, &template_context);
        assert_eq!(expanded, "PUT item 789 with action update and value 42 in format json");
    }
}