use axum::{extract::State, http::HeaderMap, Json};
use mockforge_registry_core::models::{BYOKConfig, Plan};
use serde::{Deserialize, Serialize};
use crate::{
ai::{
call_llm, check_ai_quota, pick_provider, record_ai_usage, LlmCall, Provider,
ProviderSelection,
},
error::{ApiError, ApiResult},
handlers::settings::decrypt_api_key,
middleware::{resolve_org_context, AuthUser},
AppState,
};
/// Sampling temperature used when the client does not supply one (see `chat`).
const DEFAULT_TEMPERATURE: f64 = 0.7;
/// Completion-token cap used when the client does not supply one (see `chat`).
const DEFAULT_MAX_TOKENS: u32 = 1024;
/// System prompt for the generic chat endpoint when the request omits one.
const DEFAULT_SYSTEM_PROMPT: &str = "You are a helpful AI assistant integrated into MockForge.";
/// Token-usage accounting flattened into every AI endpoint response.
#[derive(Debug, Serialize)]
pub struct UsageMeta {
/// Which backend served the call: "byok", "platform", or "disabled".
pub provider: &'static str,
/// Tokens consumed by this single call.
pub tokens_used: u64,
/// Cumulative tokens used in the current billing period (platform calls
/// include this call's tokens; see `run_completion`).
pub tokens_used_this_period: i64,
/// Period token limit as reported by the quota check.
pub tokens_limit: i64,
}
/// Normalized prompt parameters every handler builds before delegating to
/// `run_completion`.
pub(crate) struct PromptInputs {
/// System prompt sent to the model.
pub(crate) system: String,
/// User message sent to the model.
pub(crate) user: String,
/// Explicit model override; the provider's default is used when `None`.
pub(crate) model: Option<String>,
/// Sampling temperature for this call.
pub(crate) temperature: f64,
/// Maximum completion tokens for this call.
pub(crate) max_tokens: u32,
}
/// Request body for the generic chat endpoint.
#[derive(Debug, Deserialize)]
pub struct ChatRequest {
/// User message; must be non-empty after trimming.
pub prompt: String,
/// Optional system prompt; falls back to `DEFAULT_SYSTEM_PROMPT`.
#[serde(default)]
pub system: Option<String>,
/// Optional model override passed through to the provider.
#[serde(default)]
pub model: Option<String>,
/// Optional temperature; falls back to `DEFAULT_TEMPERATURE`.
#[serde(default)]
pub temperature: Option<f64>,
/// Optional token cap; falls back to `DEFAULT_MAX_TOKENS`.
#[serde(default)]
pub max_tokens: Option<u32>,
}
/// Response body for the generic chat endpoint.
#[derive(Debug, Serialize)]
pub struct ChatResponse {
/// Raw model output.
pub content: String,
/// Usage accounting, flattened into the top-level JSON object.
#[serde(flatten)]
pub meta: UsageMeta,
}
pub async fn chat(
State(state): State<AppState>,
AuthUser(user_id): AuthUser,
headers: HeaderMap,
Json(request): Json<ChatRequest>,
) -> ApiResult<Json<ChatResponse>> {
if request.prompt.trim().is_empty() {
return Err(ApiError::InvalidRequest("prompt must not be empty".into()));
}
let inputs = PromptInputs {
system: request.system.unwrap_or_else(|| DEFAULT_SYSTEM_PROMPT.into()),
user: request.prompt,
model: request.model,
temperature: request.temperature.unwrap_or(DEFAULT_TEMPERATURE),
max_tokens: request.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
};
let (content, meta) = run_completion(&state, user_id, &headers, inputs).await?;
Ok(Json(ChatResponse { content, meta }))
}
/// Request body for the OpenAPI-generation endpoint.
#[derive(Debug, Deserialize)]
pub struct GenerateOpenApiRequest {
/// Natural-language description of the API; must be non-empty after trimming.
pub description: String,
/// Optional spec title; defaults to "Generated API".
#[serde(default)]
pub title: Option<String>,
/// Optional model override passed through to the provider.
#[serde(default)]
pub model: Option<String>,
}
/// Response body for the OpenAPI-generation endpoint.
#[derive(Debug, Serialize)]
pub struct GenerateOpenApiResponse {
/// Raw model output.
pub content: String,
/// Parsed JSON spec when the output could be decoded; `None` otherwise.
pub spec: Option<serde_json::Value>,
/// Usage accounting, flattened into the top-level JSON object.
#[serde(flatten)]
pub meta: UsageMeta,
}
/// POST handler that turns a natural-language API description into an
/// OpenAPI 3.0 spec.
///
/// Returns the raw model output and, when that output parses as JSON,
/// the decoded spec document.
pub async fn generate_openapi(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    headers: HeaderMap,
    Json(request): Json<GenerateOpenApiRequest>,
) -> ApiResult<Json<GenerateOpenApiResponse>> {
    if request.description.trim().is_empty() {
        return Err(ApiError::InvalidRequest("description must not be empty".into()));
    }
    let title = request.title.as_deref().unwrap_or("Generated API");
    // `{title}` below is captured from the local variable above.
    let system = format!(
        "You are an expert API designer. Generate a complete, valid OpenAPI 3.0 \
         specification in JSON for the API described by the user. Include realistic \
         paths, request/response schemas, examples, and at least one error response \
         per endpoint. Output ONLY the JSON document, no prose, no markdown fences. \
         Use `{title}` as the spec's `info.title` unless a different title is in the \
         user's description."
    );
    // Low temperature + large token budget: the goal is a deterministic,
    // complete document rather than creative output.
    let inputs = PromptInputs {
        system,
        user: request.description,
        model: request.model,
        temperature: 0.2,
        max_tokens: 4096,
    };
    let (content, meta) = run_completion(&state, user_id, &headers, inputs).await?;
    Ok(Json(GenerateOpenApiResponse {
        spec: extract_json_payload(&content),
        content,
        meta,
    }))
}
/// Request body for the rule-explanation endpoint.
#[derive(Debug, Deserialize)]
pub struct ExplainRuleRequest {
/// Identifier of the rule being explained; must be non-empty after trimming.
pub rule_id: String,
/// The rule definition; re-serialized to pretty JSON for the prompt.
pub definition: serde_json::Value,
/// Optional free-form context appended to the prompt.
#[serde(default)]
pub context: Option<String>,
/// Optional model override passed through to the provider.
#[serde(default)]
pub model: Option<String>,
}
/// Response body for the rule-explanation endpoint.
#[derive(Debug, Serialize)]
pub struct ExplainRuleResponse {
/// The model's plain-English explanation of the rule.
pub explanation: String,
/// Usage accounting, flattened into the top-level JSON object.
#[serde(flatten)]
pub meta: UsageMeta,
}
/// POST handler that produces a plain-English explanation of a mock rule.
///
/// Serializes the rule definition to pretty JSON, optionally appends
/// caller-supplied context, and asks the model for a short explanation.
pub async fn explain_rule(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    headers: HeaderMap,
    Json(request): Json<ExplainRuleRequest>,
) -> ApiResult<Json<ExplainRuleResponse>> {
    if request.rule_id.trim().is_empty() {
        return Err(ApiError::InvalidRequest("rule_id must not be empty".into()));
    }
    let pretty_definition = serde_json::to_string_pretty(&request.definition).map_err(|e| {
        ApiError::InvalidRequest(format!("definition must be serializable JSON: {e}"))
    })?;
    let context_suffix = match &request.context {
        Some(c) => format!("\n\nContext: {c}"),
        None => String::new(),
    };
    let user = format!(
        "Rule id: {}\n\nDefinition:\n```json\n{}\n```{}",
        request.rule_id, pretty_definition, context_suffix,
    );
    let inputs = PromptInputs {
        system: "You are a senior engineer explaining MockForge mock rules to a junior \
                 teammate. Be specific: when does this rule fire, what does it return, \
                 and what edge cases does it cover? Keep it under 200 words and avoid \
                 marketing language."
            .into(),
        user,
        model: request.model,
        temperature: 0.4,
        max_tokens: 800,
    };
    let (explanation, meta) = run_completion(&state, user_id, &headers, inputs).await?;
    Ok(Json(ExplainRuleResponse { explanation, meta }))
}
/// Response body for the quota-inspection endpoint.
#[derive(Debug, Serialize)]
pub struct QuotaResponse {
/// Which backend would serve a call right now: "byok", "platform", or "disabled".
pub provider: &'static str,
/// Cumulative tokens used in the current billing period.
pub tokens_used_this_period: i64,
/// Period token limit as reported by the quota check.
pub tokens_limit: i64,
/// Whether a new AI call would currently be permitted.
pub call_allowed: bool,
}
/// GET handler reporting the org's current AI provider selection and quota.
///
/// Performs the same provider/quota resolution as `run_completion` but
/// without issuing an LLM call or recording usage.
pub async fn quota(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    headers: HeaderMap,
) -> ApiResult<Json<QuotaResponse>> {
    let org_ctx = resolve_org_context(&state, user_id, &headers, None)
        .await
        .map_err(|_| ApiError::InvalidRequest("Organization not found".into()))?;
    let byok = load_byok_config(&state, org_ctx.org_id).await?;
    let paid = matches!(org_ctx.org.plan(), Plan::Pro | Plan::Team);
    let selection = pick_provider(paid, byok).selection();
    let status = check_ai_quota(&state, &org_ctx.org, selection).await?;
    let provider = match selection {
        ProviderSelection::Byok => "byok",
        ProviderSelection::Platform => "platform",
        ProviderSelection::Disabled => "disabled",
    };
    Ok(Json(QuotaResponse {
        provider,
        tokens_used_this_period: status.used,
        tokens_limit: status.limit,
        call_allowed: status.allowed,
    }))
}
/// Request body for the voice-command parsing endpoint.
#[derive(Debug, Deserialize)]
pub struct VoiceProcessRequest {
/// Transcribed voice command; must be non-empty after trimming.
pub command: String,
/// Optional model override passed through to the provider.
#[serde(default)]
pub model: Option<String>,
}
/// Response body for the voice-command parsing endpoint.
#[derive(Debug, Serialize)]
pub struct VoiceProcessResponse {
/// Parsed intent object when the model output could be decoded as JSON.
pub intent: Option<serde_json::Value>,
/// Raw model output.
pub content: String,
/// Usage accounting, flattened into the top-level JSON object.
#[serde(flatten)]
pub meta: UsageMeta,
}
/// POST handler that parses a transcribed voice command into a structured
/// intent. The raw model output is returned alongside the decoded intent
/// (when the output parses as JSON).
pub async fn voice_process(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    headers: HeaderMap,
    Json(request): Json<VoiceProcessRequest>,
) -> ApiResult<Json<VoiceProcessResponse>> {
    if request.command.trim().is_empty() {
        return Err(ApiError::InvalidRequest("command must not be empty".into()));
    }
    // Low temperature: intent parsing should be as deterministic as possible.
    let inputs = PromptInputs {
        system: "You are a voice-command parser for MockForge. The user describes a mock \
                 they want to build. Output a single JSON object with fields: intent (one of \
                 \"create_endpoint\" | \"modify_response\" | \"add_scenario\" | \"unknown\"), \
                 confidence (0..1), and parsed (any object summarizing the parsed details). \
                 Return ONLY the JSON, no prose."
            .into(),
        user: request.command,
        model: request.model,
        temperature: 0.2,
        max_tokens: 800,
    };
    let (content, meta) = run_completion(&state, user_id, &headers, inputs).await?;
    Ok(Json(VoiceProcessResponse {
        intent: extract_json_payload(&content),
        content,
        meta,
    }))
}
/// Request body for the hook-transpilation endpoint.
#[derive(Debug, Deserialize)]
pub struct VoiceTranspileHookRequest {
/// Description of the desired hook behavior; must be non-empty after trimming.
pub description: String,
/// Optional model override passed through to the provider.
#[serde(default)]
pub model: Option<String>,
}
/// Response body for the hook-transpilation endpoint.
#[derive(Debug, Serialize)]
pub struct VoiceTranspileHookResponse {
/// Model output with surrounding markdown code fences removed.
pub hook_source: String,
/// Raw model output, fences included.
pub content: String,
/// Usage accounting, flattened into the top-level JSON object.
#[serde(flatten)]
pub meta: UsageMeta,
}
/// POST handler that turns a behavior description into MockForge
/// JavaScript hook source. Returns both the fence-stripped source and
/// the raw model output.
pub async fn voice_transpile_hook(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    headers: HeaderMap,
    Json(request): Json<VoiceTranspileHookRequest>,
) -> ApiResult<Json<VoiceTranspileHookResponse>> {
    if request.description.trim().is_empty() {
        return Err(ApiError::InvalidRequest("description must not be empty".into()));
    }
    let inputs = PromptInputs {
        system: "You write MockForge JavaScript hooks. Given the user's description, output \
                 a single JS function body (no markdown fences, no surrounding prose) that \
                 implements the described behavior. The function receives `(req, res, ctx)` \
                 and may modify `res.status`, `res.headers`, `res.body`."
            .into(),
        user: request.description,
        model: request.model,
        temperature: 0.3,
        max_tokens: 1024,
    };
    let (content, meta) = run_completion(&state, user_id, &headers, inputs).await?;
    Ok(Json(VoiceTranspileHookResponse {
        hook_source: strip_code_fences(&content),
        content,
        meta,
    }))
}
/// Request body for the workspace-scenario creation endpoint.
#[derive(Debug, Deserialize)]
pub struct VoiceCreateScenarioRequest {
/// Description of the desired scenario; must be non-empty after trimming.
pub description: String,
/// Optional workspace context appended to the prompt.
#[serde(default)]
pub workspace_context: Option<String>,
/// Optional model override passed through to the provider.
#[serde(default)]
pub model: Option<String>,
}
/// Response body for the workspace-scenario creation endpoint.
#[derive(Debug, Serialize)]
pub struct VoiceCreateScenarioResponse {
/// Parsed scenario object when the model output could be decoded as JSON.
pub scenario: Option<serde_json::Value>,
/// Raw model output.
pub content: String,
/// Usage accounting, flattened into the top-level JSON object.
#[serde(flatten)]
pub meta: UsageMeta,
}
/// POST handler that builds a workspace scenario from a natural-language
/// description, optionally enriched with workspace context. Returns the
/// raw model output plus the decoded scenario when it parses as JSON.
pub async fn voice_create_workspace_scenario(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    headers: HeaderMap,
    Json(request): Json<VoiceCreateScenarioRequest>,
) -> ApiResult<Json<VoiceCreateScenarioResponse>> {
    if request.description.trim().is_empty() {
        return Err(ApiError::InvalidRequest("description must not be empty".into()));
    }
    // The user message is the description with the workspace context appended.
    let mut user = request.description;
    if let Some(ctx) = &request.workspace_context {
        user.push_str(&format!("\n\nWorkspace context: {ctx}"));
    }
    let inputs = PromptInputs {
        system: "You build MockForge workspace scenarios. Output a single JSON object with \
                 fields: name, description, steps (array of {action, target, value} objects). \
                 Return ONLY the JSON, no prose, no markdown fences."
            .into(),
        user,
        model: request.model,
        temperature: 0.3,
        max_tokens: 1500,
    };
    let (content, meta) = run_completion(&state, user_id, &headers, inputs).await?;
    Ok(Json(VoiceCreateScenarioResponse {
        scenario: extract_json_payload(&content),
        content,
        meta,
    }))
}
/// Shared completion pipeline used by every AI endpoint.
///
/// Resolves the caller's organization, picks a provider (BYOK vs
/// platform vs disabled), enforces the AI quota, issues the LLM call,
/// and records token usage. Usage is recorded for every selection, but
/// only platform-served tokens are added to the period usage reported
/// back to the caller.
pub(crate) async fn run_completion(
    state: &AppState,
    user_id: uuid::Uuid,
    headers: &HeaderMap,
    prompt: PromptInputs,
) -> ApiResult<(String, UsageMeta)> {
    let org_ctx = resolve_org_context(state, user_id, headers, None)
        .await
        .map_err(|_| ApiError::InvalidRequest("Organization not found".into()))?;

    let byok = load_byok_config(state, org_ctx.org_id).await?;
    let paid = matches!(org_ctx.org.plan(), Plan::Pro | Plan::Team);
    let provider = pick_provider(paid, byok);
    let selection = provider.selection();

    let quota = check_ai_quota(state, &org_ctx.org, selection).await?;
    if !quota.allowed {
        return Err(quota.into_error());
    }

    let result = call_llm(build_llm_call(&provider, prompt)?).await?;
    let total_tokens = result.total_tokens();
    record_ai_usage(state, org_ctx.org_id, selection, total_tokens as i64).await?;

    // Only platform calls are reflected immediately in the reported
    // period usage; other selections report the pre-call figure.
    let billed_now = match selection {
        ProviderSelection::Platform => total_tokens as i64,
        ProviderSelection::Byok | ProviderSelection::Disabled => 0,
    };
    let meta = UsageMeta {
        provider: match selection {
            ProviderSelection::Byok => "byok",
            ProviderSelection::Platform => "platform",
            ProviderSelection::Disabled => "disabled",
        },
        tokens_used: total_tokens,
        tokens_used_this_period: quota.used + billed_now,
        tokens_limit: quota.limit,
    };
    Ok((result.content, meta))
}
/// Loads the organization's BYOK (bring-your-own-key) configuration from
/// the "byok" org setting.
///
/// Returns `Ok(None)` when the setting is absent, fails to deserialize
/// (deliberately best-effort: a corrupt setting disables BYOK instead of
/// failing the request), is disabled, or carries an empty API key.
async fn load_byok_config(state: &AppState, org_id: uuid::Uuid) -> ApiResult<Option<BYOKConfig>> {
    let Some(setting) = state.store.get_org_setting(org_id, "byok").await? else {
        return Ok(None);
    };
    let parsed: Option<BYOKConfig> = serde_json::from_value(setting.setting_value).ok();
    Ok(parsed.filter(|cfg| cfg.enabled && !cfg.api_key.is_empty()))
}
fn build_llm_call(provider: &Provider, prompt: PromptInputs) -> ApiResult<LlmCall> {
match provider {
Provider::Disabled => Err(ApiError::ResourceLimitExceeded(
"AI is not available — add a BYOK key or upgrade your plan".into(),
)),
Provider::Byok(cfg) => {
let api_key = decrypt_api_key(&cfg.api_key)?;
Ok(LlmCall {
provider: cfg.provider.clone(),
model: prompt
.model
.or_else(|| cfg.model.clone())
.unwrap_or_else(|| "gpt-4o-mini".into()),
api_key,
base_url: cfg.base_url.clone(),
system: prompt.system,
user: prompt.user,
temperature: prompt.temperature,
max_tokens: prompt.max_tokens,
})
}
Provider::Platform => {
let api_key = std::env::var("MOCKFORGE_PLATFORM_LLM_API_KEY").map_err(|_| {
ApiError::Internal(anyhow::anyhow!(
"Platform LLM not configured: MOCKFORGE_PLATFORM_LLM_API_KEY missing"
))
})?;
let provider_name = std::env::var("MOCKFORGE_PLATFORM_LLM_PROVIDER")
.unwrap_or_else(|_| "openai".into());
let default_model = std::env::var("MOCKFORGE_PLATFORM_LLM_MODEL")
.unwrap_or_else(|_| "gpt-4o-mini".into());
let endpoint = std::env::var("MOCKFORGE_PLATFORM_LLM_ENDPOINT").ok();
Ok(LlmCall {
provider: provider_name,
model: prompt.model.unwrap_or(default_model),
api_key,
base_url: endpoint,
system: prompt.system,
user: prompt.user,
temperature: prompt.temperature,
max_tokens: prompt.max_tokens,
})
}
}
}
/// Best-effort extraction of a JSON document from LLM output.
///
/// Handles output wrapped in a markdown code fence (with or without a
/// language tag) and, as a fallback, output where the model surrounds
/// the JSON with prose: in that case the widest span bounded by the
/// first opening and last closing bracket is tried. Returns `None` when
/// nothing in the text parses as JSON.
pub(crate) fn extract_json_payload(text: &str) -> Option<serde_json::Value> {
    let trimmed = text.trim();
    // Strip an optional opening fence; tolerate a "json"/"JSON" language tag.
    let stripped = trimmed
        .strip_prefix("```")
        .map(|rest| {
            let rest = rest
                .strip_prefix("json")
                .or_else(|| rest.strip_prefix("JSON"))
                .unwrap_or(rest);
            rest.trim_start()
        })
        .unwrap_or(trimmed);
    let stripped = stripped.strip_suffix("```").map(str::trim_end).unwrap_or(stripped);
    if let Ok(value) = serde_json::from_str(stripped) {
        return Some(value);
    }
    // Fallback: the model sometimes wraps the JSON in prose or uses an
    // unexpected fence tag. Try the bracketed span. Indices are byte
    // offsets of ASCII brackets, so slicing is char-boundary safe.
    let start = stripped.find(|c| c == '{' || c == '[')?;
    let end = stripped.rfind(|c| c == '}' || c == ']')?;
    if end <= start {
        return None;
    }
    serde_json::from_str(&stripped[start..=end]).ok()
}
/// Removes a surrounding markdown code fence from LLM output, returning
/// the inner source text.
///
/// Fix over the previous version: an opening fence with an unrecognized
/// language tag (e.g. ```typescript) used to leak the tag into the
/// returned source. The entire opening fence line is now discarded
/// regardless of the tag. Unfenced input is returned trimmed.
fn strip_code_fences(text: &str) -> String {
    let trimmed = text.trim();
    let after_open = match trimmed.strip_prefix("```") {
        // Drop the rest of the opening fence line (the language tag, if any).
        Some(rest) => match rest.split_once('\n') {
            Some((_tag, body)) => body,
            // Fence line with no body at all.
            None => "",
        },
        None => trimmed,
    };
    let stripped = after_open.strip_suffix("```").map(str::trim_end).unwrap_or(after_open);
    stripped.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_json_handles_plain() {
        let v = extract_json_payload(r#"{"openapi": "3.0.0"}"#).unwrap();
        assert_eq!(v["openapi"], "3.0.0");
    }

    #[test]
    fn extract_json_handles_fenced_block() {
        let v = extract_json_payload("```json\n{\"openapi\": \"3.0.0\"}\n```").unwrap();
        assert_eq!(v["openapi"], "3.0.0");
    }

    #[test]
    fn extract_json_handles_unfenced_with_whitespace() {
        let v = extract_json_payload("\n {\"x\": 1} \n").unwrap();
        assert_eq!(v["x"], 1);
    }

    #[test]
    fn extract_json_returns_none_for_prose() {
        assert!(extract_json_payload("Sure, here's the spec…").is_none());
    }

    // New coverage: strip_code_fences previously had no tests at all.
    #[test]
    fn strip_code_fences_removes_js_fence() {
        assert_eq!(strip_code_fences("```js\nres.status = 404;\n```"), "res.status = 404;");
    }

    #[test]
    fn strip_code_fences_passes_through_plain_text() {
        assert_eq!(strip_code_fences("return ctx.now();"), "return ctx.now();");
    }
}