Skip to main content

adk_studio/codegen/
validation.rs

1//! Workflow validation for code generation
2//!
3//! Validates project schemas before code generation to ensure:
4//! - Connected graph structure
5//! - Required fields are present
6//! - Tool configurations are valid
7//!
8//! Requirements: 12.4, 12.5
9
10use crate::schema::{AgentSchema, AgentType, END, ProjectSchema, START, ToolConfig};
11use std::collections::{HashMap, HashSet};
12
13/// Validation error with specific details
14#[derive(Debug, Clone)]
15pub struct ValidationError {
16    /// Error code for categorization
17    pub code: ValidationErrorCode,
18    /// Human-readable error message
19    pub message: String,
20    /// Optional context (e.g., agent ID, field name)
21    pub context: Option<String>,
22}
23
24impl std::fmt::Display for ValidationError {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        if let Some(ctx) = &self.context {
27            write!(f, "[{}] {}: {}", self.code, ctx, self.message)
28        } else {
29            write!(f, "[{}] {}", self.code, self.message)
30        }
31    }
32}
33
/// Categories of problems that validation can report.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationErrorCode {
    /// The project defines no agents.
    NoAgents,
    /// The workflow defines no edges.
    NoEdges,
    /// No edge leaves START.
    MissingStartEdge,
    /// No edge enters END.
    MissingEndEdge,
    /// A node cannot be reached from START.
    DisconnectedNode,
    /// A required field is absent.
    MissingRequiredField,
    /// A tool configuration is invalid.
    InvalidToolConfig,
    /// A route configuration is invalid.
    InvalidRouteConfig,
    /// A circular dependency was detected.
    CircularDependency,
    /// A sub-agent reference is invalid.
    InvalidSubAgentRef,
}
58
59impl std::fmt::Display for ValidationErrorCode {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            Self::NoAgents => write!(f, "NO_AGENTS"),
63            Self::NoEdges => write!(f, "NO_EDGES"),
64            Self::MissingStartEdge => write!(f, "MISSING_START"),
65            Self::MissingEndEdge => write!(f, "MISSING_END"),
66            Self::DisconnectedNode => write!(f, "DISCONNECTED"),
67            Self::MissingRequiredField => write!(f, "MISSING_FIELD"),
68            Self::InvalidToolConfig => write!(f, "INVALID_TOOL"),
69            Self::InvalidRouteConfig => write!(f, "INVALID_ROUTE"),
70            Self::CircularDependency => write!(f, "CIRCULAR_DEP"),
71            Self::InvalidSubAgentRef => write!(f, "INVALID_SUBAGENT"),
72        }
73    }
74}
75
76/// Result of workflow validation
77#[derive(Debug)]
78pub struct ValidationResult {
79    /// List of validation errors
80    pub errors: Vec<ValidationError>,
81    /// List of validation warnings (non-blocking)
82    pub warnings: Vec<ValidationError>,
83}
84
85impl ValidationResult {
86    pub fn new() -> Self {
87        Self {
88            errors: Vec::new(),
89            warnings: Vec::new(),
90        }
91    }
92
93    pub fn is_valid(&self) -> bool {
94        self.errors.is_empty()
95    }
96
97    pub fn add_error(&mut self, code: ValidationErrorCode, message: impl Into<String>) {
98        self.errors.push(ValidationError {
99            code,
100            message: message.into(),
101            context: None,
102        });
103    }
104
105    pub fn add_error_with_context(
106        &mut self,
107        code: ValidationErrorCode,
108        message: impl Into<String>,
109        context: impl Into<String>,
110    ) {
111        self.errors.push(ValidationError {
112            code,
113            message: message.into(),
114            context: Some(context.into()),
115        });
116    }
117
118    pub fn add_warning(&mut self, code: ValidationErrorCode, message: impl Into<String>) {
119        self.warnings.push(ValidationError {
120            code,
121            message: message.into(),
122            context: None,
123        });
124    }
125
126    pub fn add_warning_with_context(
127        &mut self,
128        code: ValidationErrorCode,
129        message: impl Into<String>,
130        context: impl Into<String>,
131    ) {
132        self.warnings.push(ValidationError {
133            code,
134            message: message.into(),
135            context: Some(context.into()),
136        });
137    }
138}
139
140impl Default for ValidationResult {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146/// Validate a project schema before code generation
147///
148/// Returns a ValidationResult containing any errors or warnings found.
149/// Code generation should only proceed if `result.is_valid()` returns true.
150pub fn validate_project(project: &ProjectSchema) -> ValidationResult {
151    let mut result = ValidationResult::new();
152
153    // Check for empty project
154    validate_not_empty(project, &mut result);
155    if !result.is_valid() {
156        return result;
157    }
158
159    // Validate graph connectivity
160    validate_graph_connectivity(project, &mut result);
161
162    // Validate each agent
163    for (agent_id, agent) in &project.agents {
164        validate_agent(agent_id, agent, project, &mut result);
165    }
166
167    // Validate tool configurations
168    validate_tool_configs(project, &mut result);
169
170    result
171}
172
173/// Check that the project has at least one agent OR action node, and at least one edge
174fn validate_not_empty(project: &ProjectSchema, result: &mut ValidationResult) {
175    let has_agents = !project.agents.is_empty();
176    let has_action_nodes = !project.action_nodes.is_empty();
177
178    // Allow workflows with either agents OR action nodes
179    if !has_agents && !has_action_nodes {
180        result.add_error(
181            ValidationErrorCode::NoAgents,
182            "Project must have at least one agent or action node",
183        );
184    }
185
186    if project.workflow.edges.is_empty() {
187        result.add_error(
188            ValidationErrorCode::NoEdges,
189            "Workflow must have at least one edge",
190        );
191    }
192}
193
194/// Validate that the graph is connected (all nodes reachable from START)
195fn validate_graph_connectivity(project: &ProjectSchema, result: &mut ValidationResult) {
196    // Build adjacency list
197    let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
198    let mut has_start_edge = false;
199    let mut has_end_edge = false;
200
201    for edge in &project.workflow.edges {
202        if edge.from == START {
203            has_start_edge = true;
204        }
205        if edge.to == END {
206            has_end_edge = true;
207        }
208        adjacency
209            .entry(edge.from.as_str())
210            .or_default()
211            .push(edge.to.as_str());
212    }
213
214    // Check for START and END edges
215    if !has_start_edge {
216        result.add_error(
217            ValidationErrorCode::MissingStartEdge,
218            "Workflow must have an edge from START",
219        );
220    }
221
222    if !has_end_edge {
223        result.add_error(
224            ValidationErrorCode::MissingEndEdge,
225            "Workflow must have an edge to END",
226        );
227    }
228
229    // Find all top-level agents (not sub-agents)
230    let all_sub_agents: HashSet<_> = project
231        .agents
232        .values()
233        .flat_map(|a| a.sub_agents.iter().map(|s| s.as_str()))
234        .collect();
235
236    let top_level_agents: HashSet<_> = project
237        .agents
238        .keys()
239        .filter(|id| !all_sub_agents.contains(id.as_str()))
240        .collect();
241
242    // BFS from START to find reachable nodes
243    let mut reachable: HashSet<&str> = HashSet::new();
244    let mut queue: Vec<&str> = vec![START];
245
246    while let Some(node) = queue.pop() {
247        if reachable.contains(node) {
248            continue;
249        }
250        reachable.insert(node);
251
252        if let Some(neighbors) = adjacency.get(node) {
253            for neighbor in neighbors {
254                if !reachable.contains(neighbor) {
255                    queue.push(neighbor);
256                }
257            }
258        }
259    }
260
261    // Check that all top-level agents are reachable
262    // Note: Action nodes (like trigger) are also valid nodes in the graph
263    for agent_id in &top_level_agents {
264        // Skip if this agent is reachable through an action node
265        // (e.g., START -> trigger -> agent)
266        let reachable_through_action = project.action_nodes.keys().any(|action_id| {
267            reachable.contains(action_id.as_str())
268                && project
269                    .workflow
270                    .edges
271                    .iter()
272                    .any(|e| e.from == *action_id && e.to == **agent_id)
273        });
274
275        if !reachable.contains(agent_id.as_str()) && !reachable_through_action {
276            result.add_error_with_context(
277                ValidationErrorCode::DisconnectedNode,
278                "Agent is not reachable from START",
279                agent_id.as_str(),
280            );
281        }
282    }
283}
284
285/// Validate a single agent's configuration
286fn validate_agent(
287    agent_id: &str,
288    agent: &AgentSchema,
289    project: &ProjectSchema,
290    result: &mut ValidationResult,
291) {
292    match agent.agent_type {
293        AgentType::Llm => validate_llm_agent(agent_id, agent, result),
294        AgentType::Router => validate_router_agent(agent_id, agent, project, result),
295        AgentType::Sequential | AgentType::Loop | AgentType::Parallel => {
296            validate_container_agent(agent_id, agent, project, result)
297        }
298        _ => {}
299    }
300}
301
302/// Validate LLM agent configuration
303fn validate_llm_agent(agent_id: &str, agent: &AgentSchema, result: &mut ValidationResult) {
304    // LLM agents should have a model specified
305    if agent.model.is_none() {
306        result.add_warning_with_context(
307            ValidationErrorCode::MissingRequiredField,
308            "LLM agent has no model specified, will use default",
309            agent_id,
310        );
311    }
312
313    // Check for empty instruction (warning, not error)
314    if agent.instruction.trim().is_empty() {
315        result.add_warning_with_context(
316            ValidationErrorCode::MissingRequiredField,
317            "LLM agent has no instruction, behavior may be unpredictable",
318            agent_id,
319        );
320    }
321}
322
323/// Validate router agent configuration
324fn validate_router_agent(
325    agent_id: &str,
326    agent: &AgentSchema,
327    project: &ProjectSchema,
328    result: &mut ValidationResult,
329) {
330    // Router must have routes defined
331    if agent.routes.is_empty() {
332        result.add_error_with_context(
333            ValidationErrorCode::InvalidRouteConfig,
334            "Router agent must have at least one route defined",
335            agent_id,
336        );
337        return;
338    }
339
340    // Validate each route target exists
341    for route in &agent.routes {
342        if route.target != END && !project.agents.contains_key(&route.target) {
343            result.add_error_with_context(
344                ValidationErrorCode::InvalidRouteConfig,
345                format!("Route target '{}' does not exist", route.target),
346                agent_id,
347            );
348        }
349
350        // Check for empty condition
351        if route.condition.trim().is_empty() {
352            result.add_error_with_context(
353                ValidationErrorCode::InvalidRouteConfig,
354                "Route condition cannot be empty",
355                agent_id,
356            );
357        }
358    }
359}
360
361/// Validate container agent (Sequential, Loop, Parallel) configuration
362fn validate_container_agent(
363    agent_id: &str,
364    agent: &AgentSchema,
365    project: &ProjectSchema,
366    result: &mut ValidationResult,
367) {
368    // Container must have sub-agents
369    if agent.sub_agents.is_empty() {
370        result.add_error_with_context(
371            ValidationErrorCode::MissingRequiredField,
372            format!(
373                "{:?} agent must have at least one sub-agent",
374                agent.agent_type
375            ),
376            agent_id,
377        );
378        return;
379    }
380
381    // Validate each sub-agent exists
382    for sub_id in &agent.sub_agents {
383        if !project.agents.contains_key(sub_id) {
384            result.add_error_with_context(
385                ValidationErrorCode::InvalidSubAgentRef,
386                format!("Sub-agent '{}' does not exist", sub_id),
387                agent_id,
388            );
389        }
390    }
391
392    // For loop agents, check max_iterations
393    if agent.agent_type == AgentType::Loop {
394        if let Some(max_iter) = agent.max_iterations {
395            if max_iter == 0 {
396                result.add_warning_with_context(
397                    ValidationErrorCode::MissingRequiredField,
398                    "Loop agent has max_iterations=0, will not execute",
399                    agent_id,
400                );
401            }
402        }
403    }
404}
405
406/// Validate tool configurations
407fn validate_tool_configs(project: &ProjectSchema, result: &mut ValidationResult) {
408    for (tool_id, config) in &project.tool_configs {
409        match config {
410            ToolConfig::Mcp(mcp) => {
411                if mcp.server_command.trim().is_empty() {
412                    result.add_error_with_context(
413                        ValidationErrorCode::InvalidToolConfig,
414                        "MCP tool must have a server command",
415                        tool_id,
416                    );
417                }
418            }
419            ToolConfig::Function(func) => {
420                if func.name.trim().is_empty() {
421                    result.add_error_with_context(
422                        ValidationErrorCode::InvalidToolConfig,
423                        "Function tool must have a name",
424                        tool_id,
425                    );
426                }
427                if func.description.trim().is_empty() {
428                    result.add_warning_with_context(
429                        ValidationErrorCode::InvalidToolConfig,
430                        "Function tool has no description",
431                        tool_id,
432                    );
433                }
434            }
435            ToolConfig::Browser(_) => {
436                // Browser config has sensible defaults, no validation needed
437            }
438        }
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{Edge, Position, Route};

    /// Model name used by every fixture agent that specifies a model.
    const TEST_MODEL: &str = "gemini-3.1-flash-lite-preview";

    /// Builds an agent of the given type with no tools, sub-agents, or routes.
    fn agent(agent_type: AgentType, model: Option<&str>, instruction: &str) -> AgentSchema {
        AgentSchema {
            agent_type,
            model: model.map(String::from),
            instruction: instruction.to_string(),
            tools: vec![],
            sub_agents: vec![],
            position: Position::default(),
            max_iterations: None,
            temperature: None,
            top_p: None,
            top_k: None,
            max_output_tokens: None,
            routes: vec![],
        }
    }

    /// Minimal valid project: `START -> agent1 -> END` with one LLM agent.
    fn create_test_project() -> ProjectSchema {
        let mut project = ProjectSchema::new("test");
        project.agents.insert(
            "agent1".to_string(),
            agent(AgentType::Llm, Some(TEST_MODEL), "Test instruction"),
        );
        project.workflow.edges = vec![Edge::new(START, "agent1"), Edge::new("agent1", END)];
        project
    }

    /// Replaces the workflow with `START -> node -> END`.
    fn route_through(project: &mut ProjectSchema, node: &str) {
        project.workflow.edges = vec![Edge::new(START, node), Edge::new(node, END)];
    }

    /// Asserts that validation fails and reports at least one error with `code`.
    #[track_caller]
    fn assert_rejected_with(project: &ProjectSchema, code: ValidationErrorCode) {
        let result = validate_project(project);
        assert!(!result.is_valid());
        assert!(result.errors.iter().any(|e| e.code == code));
    }

    #[test]
    fn test_valid_project() {
        let project = create_test_project();
        let result = validate_project(&project);
        assert!(
            result.is_valid(),
            "Expected valid project, got errors: {:?}",
            result.errors
        );
    }

    #[test]
    fn test_empty_agents() {
        let mut project = create_test_project();
        project.agents.clear();
        assert_rejected_with(&project, ValidationErrorCode::NoAgents);
    }

    #[test]
    fn test_empty_edges() {
        let mut project = create_test_project();
        project.workflow.edges.clear();
        assert_rejected_with(&project, ValidationErrorCode::NoEdges);
    }

    #[test]
    fn test_missing_start_edge() {
        let mut project = create_test_project();
        project.workflow.edges = vec![Edge::new("agent1", END)];
        assert_rejected_with(&project, ValidationErrorCode::MissingStartEdge);
    }

    #[test]
    fn test_missing_end_edge() {
        let mut project = create_test_project();
        project.workflow.edges = vec![Edge::new(START, "agent1")];
        assert_rejected_with(&project, ValidationErrorCode::MissingEndEdge);
    }

    #[test]
    fn test_disconnected_node() {
        let mut project = create_test_project();
        // agent2 is added but never wired into the graph.
        project.agents.insert(
            "agent2".to_string(),
            agent(AgentType::Llm, Some(TEST_MODEL), "Disconnected"),
        );
        assert_rejected_with(&project, ValidationErrorCode::DisconnectedNode);
    }

    #[test]
    fn test_router_without_routes() {
        let mut project = create_test_project();
        // The helper leaves `routes` empty, which is what this test needs.
        project.agents.insert(
            "router".to_string(),
            agent(AgentType::Router, Some(TEST_MODEL), "Route"),
        );
        route_through(&mut project, "router");
        assert_rejected_with(&project, ValidationErrorCode::InvalidRouteConfig);
    }

    #[test]
    fn test_router_with_invalid_target() {
        let mut project = create_test_project();
        let mut router = agent(AgentType::Router, Some(TEST_MODEL), "Route");
        router.routes = vec![Route {
            condition: "test".to_string(),
            target: "nonexistent".to_string(), // no such agent
        }];
        project.agents.insert("router".to_string(), router);
        route_through(&mut project, "router");
        assert_rejected_with(&project, ValidationErrorCode::InvalidRouteConfig);
    }

    #[test]
    fn test_sequential_without_subagents() {
        let mut project = create_test_project();
        // The helper leaves `sub_agents` empty, which is what this test needs.
        project
            .agents
            .insert("seq".to_string(), agent(AgentType::Sequential, None, ""));
        route_through(&mut project, "seq");
        assert_rejected_with(&project, ValidationErrorCode::MissingRequiredField);
    }

    #[test]
    fn test_sequential_with_invalid_subagent() {
        let mut project = create_test_project();
        let mut seq = agent(AgentType::Sequential, None, "");
        seq.sub_agents = vec!["nonexistent".to_string()]; // no such agent
        project.agents.insert("seq".to_string(), seq);
        route_through(&mut project, "seq");
        assert_rejected_with(&project, ValidationErrorCode::InvalidSubAgentRef);
    }
}
695
696/// Get a list of required environment variables for a project
697///
698/// This function analyzes the project schema and returns a list of
699/// environment variables that must be set for the generated code to run.
700///
701/// Requirement: 12.10 - Warn when required env vars are missing
702pub fn get_required_env_vars(project: &ProjectSchema) -> Vec<EnvVarRequirement> {
703    let mut env_vars = Vec::new();
704
705    // Detect which providers are used across all agents
706    let providers = super::collect_providers(project);
707
708    if providers.contains("gemini") {
709        env_vars.push(EnvVarRequirement {
710            name: "GOOGLE_API_KEY".to_string(),
711            description: "Google AI API key for Gemini models".to_string(),
712            alternatives: vec!["GEMINI_API_KEY".to_string()],
713            required: true,
714        });
715    }
716
717    if providers.contains("openai") {
718        env_vars.push(EnvVarRequirement {
719            name: "OPENAI_API_KEY".to_string(),
720            description: "OpenAI API key for GPT models".to_string(),
721            alternatives: vec![],
722            required: true,
723        });
724    }
725
726    if providers.contains("anthropic") {
727        env_vars.push(EnvVarRequirement {
728            name: "ANTHROPIC_API_KEY".to_string(),
729            description: "Anthropic API key for Claude models".to_string(),
730            alternatives: vec![],
731            required: true,
732        });
733    }
734
735    if providers.contains("deepseek") {
736        env_vars.push(EnvVarRequirement {
737            name: "DEEPSEEK_API_KEY".to_string(),
738            description: "DeepSeek API key".to_string(),
739            alternatives: vec![],
740            required: true,
741        });
742    }
743
744    if providers.contains("groq") {
745        env_vars.push(EnvVarRequirement {
746            name: "GROQ_API_KEY".to_string(),
747            description: "Groq API key for fast inference".to_string(),
748            alternatives: vec![],
749            required: true,
750        });
751    }
752
753    if providers.contains("ollama") {
754        env_vars.push(EnvVarRequirement {
755            name: "OLLAMA_HOST".to_string(),
756            description: "Ollama server URL (defaults to http://localhost:11434)".to_string(),
757            alternatives: vec![],
758            required: false, // Ollama defaults to localhost
759        });
760    }
761
762    if providers.contains("fireworks") {
763        env_vars.push(EnvVarRequirement {
764            name: "FIREWORKS_API_KEY".to_string(),
765            description: "Fireworks AI API key".to_string(),
766            alternatives: vec![],
767            required: true,
768        });
769    }
770
771    if providers.contains("together") {
772        env_vars.push(EnvVarRequirement {
773            name: "TOGETHER_API_KEY".to_string(),
774            description: "Together AI API key".to_string(),
775            alternatives: vec![],
776            required: true,
777        });
778    }
779
780    if providers.contains("mistral") {
781        env_vars.push(EnvVarRequirement {
782            name: "MISTRAL_API_KEY".to_string(),
783            description: "Mistral AI API key".to_string(),
784            alternatives: vec![],
785            required: true,
786        });
787    }
788
789    if providers.contains("perplexity") {
790        env_vars.push(EnvVarRequirement {
791            name: "PERPLEXITY_API_KEY".to_string(),
792            description: "Perplexity API key for Sonar models".to_string(),
793            alternatives: vec![],
794            required: true,
795        });
796    }
797
798    if providers.contains("cerebras") {
799        env_vars.push(EnvVarRequirement {
800            name: "CEREBRAS_API_KEY".to_string(),
801            description: "Cerebras API key for ultra-fast inference".to_string(),
802            alternatives: vec![],
803            required: true,
804        });
805    }
806
807    if providers.contains("sambanova") {
808        env_vars.push(EnvVarRequirement {
809            name: "SAMBANOVA_API_KEY".to_string(),
810            description: "SambaNova API key".to_string(),
811            alternatives: vec![],
812            required: true,
813        });
814    }
815
816    if providers.contains("bedrock") {
817        env_vars.push(EnvVarRequirement {
818            name: "AWS_ACCESS_KEY_ID".to_string(),
819            description: "AWS credentials for Amazon Bedrock (or use IAM roles/SSO)".to_string(),
820            alternatives: vec!["AWS_PROFILE".to_string()],
821            required: true,
822        });
823        env_vars.push(EnvVarRequirement {
824            name: "AWS_DEFAULT_REGION".to_string(),
825            description: "AWS region for Bedrock (defaults to us-east-1)".to_string(),
826            alternatives: vec!["AWS_REGION".to_string()],
827            required: false,
828        });
829    }
830
831    if providers.contains("azure-ai") {
832        env_vars.push(EnvVarRequirement {
833            name: "AZURE_AI_ENDPOINT".to_string(),
834            description: "Azure AI Inference endpoint URL".to_string(),
835            alternatives: vec![],
836            required: true,
837        });
838        env_vars.push(EnvVarRequirement {
839            name: "AZURE_AI_API_KEY".to_string(),
840            description: "Azure AI API key".to_string(),
841            alternatives: vec![],
842            required: true,
843        });
844    }
845
846    // Check for MCP tools that might need specific env vars
847    for (tool_id, config) in &project.tool_configs {
848        if let ToolConfig::Mcp(mcp) = config {
849            // Common MCP servers that need env vars
850            if mcp.server_command.contains("github")
851                || mcp.server_args.iter().any(|a| a.contains("github"))
852            {
853                env_vars.push(EnvVarRequirement {
854                    name: "GITHUB_TOKEN".to_string(),
855                    description: format!("GitHub token for MCP server ({})", tool_id),
856                    alternatives: vec!["GITHUB_PERSONAL_ACCESS_TOKEN".to_string()],
857                    required: false, // May work without for public repos
858                });
859            }
860
861            if mcp.server_command.contains("slack")
862                || mcp.server_args.iter().any(|a| a.contains("slack"))
863            {
864                env_vars.push(EnvVarRequirement {
865                    name: "SLACK_BOT_TOKEN".to_string(),
866                    description: format!("Slack bot token for MCP server ({})", tool_id),
867                    alternatives: vec![],
868                    required: true,
869                });
870            }
871        }
872    }
873
874    // Check for browser tool (may need headless browser setup)
875    let uses_browser = project
876        .agents
877        .values()
878        .any(|a| a.tools.contains(&"browser".to_string()));
879    if uses_browser {
880        env_vars.push(EnvVarRequirement {
881            name: "CHROME_PATH".to_string(),
882            description: "Path to Chrome/Chromium executable (optional, auto-detected if not set)"
883                .to_string(),
884            alternatives: vec!["CHROMIUM_PATH".to_string()],
885            required: false,
886        });
887    }
888
889    env_vars
890}
891
/// An environment variable that generated code reads at run time.
#[derive(Debug, Clone)]
pub struct EnvVarRequirement {
    /// Primary environment variable name
    pub name: String,
    /// Description of what this variable is used for
    pub description: String,
    /// Alternative variable names that can be used instead of `name`
    pub alternatives: Vec<String>,
    /// Whether this variable is required (vs optional)
    pub required: bool,
}

impl EnvVarRequirement {
    /// Check if this environment variable is set.
    ///
    /// Returns `true` when the primary name or any alternative holds a
    /// non-blank value. A variable that is exported but empty (or only
    /// whitespace) counts as missing, since a blank value is no more usable
    /// than an absent one. A value that is not valid Unicode also counts as
    /// missing, because `std::env::var` reports it as an error.
    pub fn is_set(&self) -> bool {
        self.all_names()
            .into_iter()
            .any(|name| matches!(std::env::var(name), Ok(value) if !value.trim().is_empty()))
    }

    /// Get all possible variable names: the primary name first, followed by
    /// the alternatives in declaration order.
    pub fn all_names(&self) -> Vec<&str> {
        std::iter::once(self.name.as_str())
            .chain(self.alternatives.iter().map(String::as_str))
            .collect()
    }
}
923
924/// Check for missing required environment variables
925///
926/// Returns a list of warnings for missing environment variables.
927/// This is used to warn users before they try to run the generated code.
928pub fn check_env_vars(project: &ProjectSchema) -> Vec<EnvVarWarning> {
929    let requirements = get_required_env_vars(project);
930    let mut warnings = Vec::new();
931
932    for req in requirements {
933        if !req.is_set() {
934            warnings.push(EnvVarWarning {
935                variable: req.name.clone(),
936                description: req.description.clone(),
937                alternatives: req.alternatives.clone(),
938                required: req.required,
939            });
940        }
941    }
942
943    warnings
944}
945
946/// Warning about a missing environment variable
947#[derive(Debug, Clone, serde::Serialize)]
948pub struct EnvVarWarning {
949    /// Primary environment variable name
950    pub variable: String,
951    /// Description of what this variable is used for
952    pub description: String,
953    /// Alternative variable names that can be used
954    pub alternatives: Vec<String>,
955    /// Whether this variable is required (vs optional)
956    pub required: bool,
957}
958
959impl std::fmt::Display for EnvVarWarning {
960    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
961        if self.required {
962            write!(f, "Required: {} - {}", self.variable, self.description)?;
963        } else {
964            write!(f, "Optional: {} - {}", self.variable, self.description)?;
965        }
966        if !self.alternatives.is_empty() {
967            write!(f, " (alternatives: {})", self.alternatives.join(", "))?;
968        }
969        Ok(())
970    }
971}
972
#[cfg(test)]
mod env_var_tests {
    use super::*;
    use crate::schema::{McpToolConfig, Position};

    /// Builds a project named "test" containing a single LLM agent with id
    /// "agent" that uses the given tool names.
    fn project_with_agent_tools(tools: &[&str]) -> ProjectSchema {
        let agent = AgentSchema {
            agent_type: AgentType::Llm,
            model: Some("gemini-3.1-flash-lite-preview".to_string()),
            instruction: "Test".to_string(),
            tools: tools.iter().map(|tool| tool.to_string()).collect(),
            sub_agents: vec![],
            position: Position::default(),
            max_iterations: None,
            temperature: None,
            top_p: None,
            top_k: None,
            max_output_tokens: None,
            routes: vec![],
        };

        let mut project = ProjectSchema::new("test");
        project.agents.insert("agent".to_string(), agent);
        project
    }

    #[test]
    fn test_gemini_requires_api_key() {
        let project = project_with_agent_tools(&[]);

        let requirements = get_required_env_vars(&project);
        assert!(requirements.iter().any(|req| req.name == "GOOGLE_API_KEY"));
    }

    #[test]
    fn test_browser_tool_env_var() {
        let project = project_with_agent_tools(&["browser"]);

        let requirements = get_required_env_vars(&project);
        assert!(requirements.iter().any(|req| req.name == "CHROME_PATH"));
    }

    #[test]
    fn test_github_mcp_env_var() {
        let mut project = project_with_agent_tools(&["mcp"]);

        // MCP server configuration whose arguments mention the GitHub server package.
        let github_server = McpToolConfig {
            server_command: "npx".to_string(),
            server_args: vec![
                "-y".to_string(),
                "@modelcontextprotocol/server-github".to_string(),
            ],
            tool_filter: vec![],
        };
        project
            .tool_configs
            .insert("agent_mcp".to_string(), ToolConfig::Mcp(github_server));

        let requirements = get_required_env_vars(&project);
        assert!(requirements.iter().any(|req| req.name == "GITHUB_TOKEN"));
    }
}