Skip to main content

ai_agent/utils/permissions/
yolo_classifier.rs

1// Source: ~/claudecode/openclaudecode/src/utils/permissions/yoloClassifier.ts
2#![allow(dead_code)]
3
4//! YOLO (auto mode) classifier for security decisions.
5//!
6//! Uses an LLM to classify whether agent actions should be allowed or blocked.
7
8use super::bash_classifier::{
9    get_bash_prompt_allow_descriptions, get_bash_prompt_deny_descriptions,
10};
11use super::classifier_shared::{ContentBlock, extract_tool_use_block};
12use crate::types::permissions::{ClassifierUsage, ToolPermissionContext, YoloClassifierResult};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
16/// YOLO classifier tool name.
17pub const YOLO_CLASSIFIER_TOOL_NAME: &str = "classify_result";
18
19/// Transcript block for the classifier.
20#[derive(Debug, Clone)]
21pub enum TranscriptBlock {
22    Text {
23        text: String,
24    },
25    ToolUse {
26        name: String,
27        input: serde_json::Value,
28    },
29}
30
31/// Transcript entry.
32#[derive(Debug, Clone)]
33pub struct TranscriptEntry {
34    pub role: String, // "user" or "assistant"
35    pub content: Vec<TranscriptBlock>,
36}
37
38/// Builds transcript entries from messages.
39pub fn build_transcript_entries(messages: &[serde_json::Value]) -> Vec<TranscriptEntry> {
40    let mut transcript = Vec::new();
41
42    for msg in messages {
43        if let Some(msg_type) = msg.get("type").and_then(|v| v.as_str()) {
44            match msg_type {
45                "user" => {
46                    if let Some(content) = msg.get("message").and_then(|m| m.get("content")) {
47                        let text_blocks = extract_text_blocks(content);
48                        if !text_blocks.is_empty() {
49                            transcript.push(TranscriptEntry {
50                                role: "user".to_string(),
51                                content: text_blocks
52                                    .into_iter()
53                                    .map(|t| TranscriptBlock::Text { text: t })
54                                    .collect(),
55                            });
56                        }
57                    }
58                }
59                "assistant" => {
60                    if let Some(content) = msg.get("message").and_then(|m| m.get("content")) {
61                        let blocks = extract_tool_use_blocks(content);
62                        if !blocks.is_empty() {
63                            transcript.push(TranscriptEntry {
64                                role: "assistant".to_string(),
65                                content: blocks,
66                            });
67                        }
68                    }
69                }
70                _ => {}
71            }
72        }
73    }
74
75    transcript
76}
77
78fn extract_text_blocks(content: &serde_json::Value) -> Vec<String> {
79    let mut texts = Vec::new();
80    if let Some(s) = content.as_str() {
81        texts.push(s.to_string());
82    } else if let Some(arr) = content.as_array() {
83        for block in arr {
84            if let Some(block_type) = block.get("type").and_then(|v| v.as_str()) {
85                if block_type == "text" {
86                    if let Some(text) = block.get("text").and_then(|v| v.as_str()) {
87                        texts.push(text.to_string());
88                    }
89                }
90            }
91        }
92    }
93    texts
94}
95
96fn extract_tool_use_blocks(content: &serde_json::Value) -> Vec<TranscriptBlock> {
97    let mut blocks = Vec::new();
98    if let Some(arr) = content.as_array() {
99        for block in arr {
100            if let Some(block_type) = block.get("type").and_then(|v| v.as_str()) {
101                if block_type == "tool_use" {
102                    if let (Some(name), Some(input)) = (
103                        block.get("name").and_then(|v| v.as_str()),
104                        block.get("input"),
105                    ) {
106                        blocks.push(TranscriptBlock::ToolUse {
107                            name: name.to_string(),
108                            input: input.clone(),
109                        });
110                    }
111                }
112            }
113        }
114    }
115    blocks
116}
117
118/// Builds a compact transcript string for the classifier.
119pub fn build_transcript_for_classifier(
120    messages: &[serde_json::Value],
121    _tools: &[serde_json::Value],
122) -> String {
123    let entries = build_transcript_entries(messages);
124    let mut result = String::new();
125
126    for entry in entries {
127        for block in entry.content {
128            match block {
129                TranscriptBlock::Text { text } => {
130                    result.push_str(&format!("User: {}\n", text));
131                }
132                TranscriptBlock::ToolUse { name, input } => {
133                    let input_str = serde_json::to_string(&input).unwrap_or_default();
134                    result.push_str(&format!("{} {}\n", name, input_str));
135                }
136            }
137        }
138    }
139
140    result
141}
142
143/// Formats an action for the classifier.
144pub fn format_action_for_classifier(
145    tool_name: &str,
146    tool_input: serde_json::Value,
147) -> TranscriptEntry {
148    TranscriptEntry {
149        role: "assistant".to_string(),
150        content: vec![TranscriptBlock::ToolUse {
151            name: tool_name.to_string(),
152            input: tool_input,
153        }],
154    }
155}
156
157/// YOLO classifier response schema.
158#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct YoloClassifierResponse {
160    pub thinking: String,
161    #[serde(rename = "shouldBlock")]
162    pub should_block: bool,
163    pub reason: String,
164}
165
166/// Builds the YOLO system prompt.
167pub async fn build_yolo_system_prompt(context: &ToolPermissionContext) -> String {
168    let base_prompt = get_base_prompt();
169    let permissions_template = get_permissions_template(context);
170
171    let auto_mode_allow = get_auto_mode_allow(context);
172    let auto_mode_deny = get_auto_mode_deny(context);
173    let auto_mode_env = get_auto_mode_environment(context);
174
175    let mut system_prompt = base_prompt.replace("<permissions_template>", &permissions_template);
176
177    if !auto_mode_allow.is_empty() {
178        system_prompt = system_prompt.replace(
179            "<user_allow_rules_to_replace></user_allow_rules_to_replace>",
180            &auto_mode_allow,
181        );
182    }
183    if !auto_mode_deny.is_empty() {
184        system_prompt = system_prompt.replace(
185            "<user_deny_rules_to_replace></user_deny_rules_to_replace>",
186            &auto_mode_deny,
187        );
188    }
189    if !auto_mode_env.is_empty() {
190        system_prompt = system_prompt.replace(
191            "<user_environment_to_replace></user_environment_to_replace>",
192            &auto_mode_env,
193        );
194    }
195
196    system_prompt
197}
198
/// Returns the classifier's base system prompt.
///
/// The prompt ends with a `<permissions_template>` placeholder. The prompt
/// builders replace that placeholder with the active permissions template;
/// without it their `replace("<permissions_template>", …)` calls would be
/// no-ops and the allow/deny rules would never reach the classifier.
fn get_base_prompt() -> String {
    // In a full implementation, this would load from a bundled file.
    "You are a security classifier for an AI agent. Evaluate whether actions should be allowed or blocked.\n\n<permissions_template>".to_string()
}
203
204fn get_permissions_template(context: &ToolPermissionContext) -> String {
205    // Check if using external permissions
206    if is_using_external_permissions() {
207        get_external_permissions_template()
208    } else {
209        get_anthropic_permissions_template()
210    }
211}
212
/// Reports whether the external permissions template applies.
///
/// Only `USER_TYPE=ant` counts as internal. An unset or non-Unicode
/// `USER_TYPE`, or any other value, is treated as external.
fn is_using_external_permissions() -> bool {
    match std::env::var("USER_TYPE") {
        Ok(user_type) => user_type != "ant",
        Err(_) => true,
    }
}
216
/// Permissions template for external users.
///
/// Every rule section is an empty placeholder pair, to be filled with
/// user-provided rules.
fn get_external_permissions_template() -> String {
    String::from(concat!(
        "## Allow Rules\n",
        "<user_allow_rules_to_replace></user_allow_rules_to_replace>\n",
        "\n",
        "## Deny Rules\n",
        "<user_deny_rules_to_replace></user_deny_rules_to_replace>\n",
        "\n",
        "## Environment\n",
        "<user_environment_to_replace></user_environment_to_replace>"
    ))
}
220
/// Permissions template for internal users.
///
/// Each placeholder section wraps a small set of default bullet rules.
fn get_anthropic_permissions_template() -> String {
    String::from(concat!(
        "## Allow Rules\n",
        "<user_allow_rules_to_replace>- Read-only file access\n",
        "- Safe development commands</user_allow_rules_to_replace>\n",
        "\n",
        "## Deny Rules\n",
        "<user_deny_rules_to_replace>- Code execution without explicit permission\n",
        "- Network exfiltration attempts\n",
        "- Irreversible file deletion</user_deny_rules_to_replace>\n",
        "\n",
        "## Environment\n",
        "<user_environment_to_replace>- Development environment\n",
        "- Standard project structure</user_environment_to_replace>"
    ))
}
224
225fn get_auto_mode_allow(context: &ToolPermissionContext) -> String {
226    let mut allow = Vec::new();
227
228    // Include bash prompt rules for ant builds
229    if std::env::var("USER_TYPE").as_deref() == Ok("ant") && !is_using_external_permissions() {
230        allow.extend(get_bash_prompt_allow_descriptions(&()));
231    }
232
233    allow.sort();
234    allow.dedup();
235
236    allow
237        .iter()
238        .map(|d| format!("- {}", d))
239        .collect::<Vec<_>>()
240        .join("\n")
241}
242
243fn get_auto_mode_deny(context: &ToolPermissionContext) -> String {
244    let mut deny = Vec::new();
245
246    if std::env::var("USER_TYPE").as_deref() == Ok("ant") && !is_using_external_permissions() {
247        deny.extend(get_bash_prompt_deny_descriptions(&()));
248    }
249
250    // PowerShell deny guidance
251    if std::env::var("CLAUDE_CODE_POWERSHELL_AUTO_MODE").as_deref() == Ok("1")
252        || std::env::var("CLAUDE_CODE_POWERSHELL_AUTO_MODE").as_deref() == Ok("true")
253    {
254        deny.extend(get_powershell_deny_guidance());
255    }
256
257    deny.sort();
258    deny.dedup();
259
260    deny.iter()
261        .map(|d| format!("- {}", d))
262        .collect::<Vec<_>>()
263        .join("\n")
264}
265
266fn get_auto_mode_environment(context: &ToolPermissionContext) -> String {
267    let _ = context;
268    String::new()
269}
270
/// PowerShell-specific deny guidance that maps common PowerShell idioms onto
/// existing deny categories.
fn get_powershell_deny_guidance() -> Vec<String> {
    const GUIDANCE: [&str; 4] = [
        "PowerShell Download-and-Execute: iex (iwr ...) and similar fall under \"Code from External\".",
        "PowerShell Irreversible Destruction: Remove-Item -Recurse -Force falls under \"Irreversible Local Destruction\".",
        "PowerShell Persistence: modifying $PROFILE or registry Run keys falls under \"Unauthorized Persistence\".",
        "PowerShell Elevation: Start-Process -Verb RunAs falls under \"Security Weaken\".",
    ];
    GUIDANCE.iter().map(|line| line.to_string()).collect()
}
279
280/// Classifies a YOLO action.
281pub async fn classify_yolo_action(
282    _messages: &[serde_json::Value],
283    _action: TranscriptEntry,
284    _tools: &[serde_json::Value],
285    _context: &ToolPermissionContext,
286    _signal: &tokio::sync::oneshot::Receiver<()>,
287) -> YoloClassifierResult {
288    // In a full implementation, this would call the LLM API
289    YoloClassifierResult {
290        thinking: None,
291        should_block: false,
292        reason: "Classifier not available".to_string(),
293        unavailable: Some(true),
294        transcript_too_long: None,
295        model: "unknown".to_string(),
296        usage: None,
297        duration_ms: None,
298        prompt_lengths: None,
299        error_dump_path: None,
300        stage: None,
301        stage1_usage: None,
302        stage1_duration_ms: None,
303        stage1_request_id: None,
304        stage1_msg_id: None,
305        stage2_usage: None,
306        stage2_duration_ms: None,
307        stage2_request_id: None,
308        stage2_msg_id: None,
309    }
310}
311
/// Path of the `claude-auto-mode` directory under the system temp dir, as a
/// string (non-UTF-8 components are replaced lossily).
pub fn get_auto_mode_dump_dir() -> String {
    let dump_dir = std::env::temp_dir().join("claude-auto-mode");
    dump_dir.to_string_lossy().into_owned()
}
317
/// Path of `auto-mode-classifier-errors` under the system temp dir, as a
/// string (non-UTF-8 components are replaced lossily).
pub fn get_auto_mode_classifier_error_dump_path() -> String {
    let error_path = std::env::temp_dir().join("auto-mode-classifier-errors");
    error_path.to_string_lossy().into_owned()
}
325
/// Returns the auto-mode classifier transcript, if one is available.
///
/// Reading it from session state is not implemented yet, so this always
/// returns `None`.
pub fn get_auto_mode_classifier_transcript() -> Option<String> {
    Option::None
}
331
/// Reports whether the JSONL transcript format is enabled.
///
/// True only when `CLAUDE_CODE_JSONL_TRANSCRIPT` is exactly `1` or `true`.
/// An unset or non-Unicode value counts as disabled.
pub fn is_jsonl_transcript_enabled_yolo() -> bool {
    match std::env::var("CLAUDE_CODE_JSONL_TRANSCRIPT") {
        Ok(flag) => flag == "1" || flag == "true",
        Err(_) => false,
    }
}
339
/// Default auto-mode rules for external users.
///
/// Contains the keys `allow`, `soft_deny` and `environment`, each mapped to an
/// empty rule list.
pub fn get_default_external_auto_mode_rules() -> HashMap<String, Vec<String>> {
    ["allow", "soft_deny", "environment"]
        .iter()
        .map(|section| (section.to_string(), Vec::new()))
        .collect()
}
348
349/// Builds the default external system prompt.
350pub fn build_default_external_system_prompt() -> String {
351    let base = get_base_prompt();
352    let template = get_external_permissions_template();
353    base.replace("<permissions_template>", &template)
354}