ricecoder_generation/
prompt_builder.rs

1//! Prompt building for AI code generation
2//!
//! Builds system and user prompts from generation plans and optional
//! specification and design documents, attaches that content as prompt
//! context, and applies steering rules.
5
6use crate::error::GenerationError;
7use crate::spec_processor::GenerationPlan;
8use ricecoder_storage::PathResolver;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12
/// Builds prompts for AI code generation.
///
/// Combines a [`GenerationPlan`] with optional spec/design text and steering
/// rules into a [`GeneratedPrompt`], rejecting prompts whose estimated size
/// exceeds `max_context_tokens`.
#[derive(Debug, Clone)]
pub struct PromptBuilder {
    /// Upper bound on the estimated token count of a built prompt; `build`
    /// returns an error when the estimate exceeds it.
    pub max_context_tokens: usize,
    /// Project root; the project-level steering directory is resolved
    /// relative to this path.
    pub project_root: PathBuf,
}
21
/// A built prompt for AI generation, as returned by `PromptBuilder::build`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedPrompt {
    /// Unique identifier of the form `prompt-<uuid>`.
    pub id: String,
    /// System prompt: instructions for the AI derived from the steering rules.
    pub system_prompt: String,
    /// User prompt: the actual request, rendered from the generation plan's steps.
    pub user_prompt: String,
    /// Spec/design content and steering-rule text attached to the prompt.
    pub context: PromptContext,
    /// Steering rules applied. NOTE(review): `PromptBuilder::build` currently
    /// fills this with the language keys of the naming conventions
    /// (e.g. `"rust"`), not the rule text itself.
    pub steering_rules_applied: Vec<String>,
    /// Estimated token count (rough ~4-bytes-per-token heuristic) of the
    /// system prompt, user prompt, and context.
    pub estimated_tokens: usize,
}
38
/// Context included in a prompt alongside the system and user prompts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptContext {
    /// Spec content, if one was supplied when building the prompt.
    pub spec_content: Option<String>,
    /// Design content, if one was supplied when building the prompt.
    pub design_content: Option<String>,
    /// Project examples (not populated by `PromptBuilder` at present).
    pub examples: Vec<String>,
    /// Architecture documentation (not populated by `PromptBuilder` at present).
    pub architecture_docs: Vec<String>,
    /// Steering rules rendered as plain-text lines.
    pub steering_rules: Vec<String>,
}
53
/// Steering rules that shape generated code.
///
/// Produced by `PromptBuilder::load_steering_rules` from the project and
/// global steering directories. NOTE(review): steering files are not parsed
/// yet; the lists are currently filled with built-in defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SteeringRules {
    /// Naming convention per language, keyed by language name
    /// (e.g. `"rust"` -> `"snake_case"`).
    pub naming_conventions: HashMap<String, String>,
    /// Code quality standards, one rule per entry.
    pub code_quality_standards: Vec<String>,
    /// Documentation requirements, one rule per entry.
    pub documentation_requirements: Vec<String>,
    /// Error handling patterns, one rule per entry.
    pub error_handling_patterns: Vec<String>,
    /// Testing requirements, one rule per entry.
    pub testing_requirements: Vec<String>,
}
68
69impl PromptBuilder {
70    /// Creates a new PromptBuilder
71    pub fn new(project_root: PathBuf) -> Self {
72        Self {
73            max_context_tokens: 4000,
74            project_root,
75        }
76    }
77
78    /// Sets the maximum context tokens
79    pub fn with_max_context_tokens(mut self, tokens: usize) -> Self {
80        self.max_context_tokens = tokens;
81        self
82    }
83
84    /// Builds a prompt from a generation plan
85    ///
86    /// # Arguments
87    ///
88    /// * `plan` - The generation plan
89    /// * `spec_content` - Optional spec file content
90    /// * `design_content` - Optional design file content
91    ///
92    /// # Returns
93    ///
94    /// A generated prompt ready for AI
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if prompt building fails
99    pub fn build(
100        &self,
101        plan: &GenerationPlan,
102        spec_content: Option<&str>,
103        design_content: Option<&str>,
104    ) -> Result<GeneratedPrompt, GenerationError> {
105        // Load steering rules
106        let steering_rules = self.load_steering_rules()?;
107
108        // Build context
109        let context = self.build_context(spec_content, design_content, &steering_rules)?;
110
111        // Build system prompt
112        let system_prompt = self.build_system_prompt(&steering_rules)?;
113
114        // Build user prompt from generation plan
115        let user_prompt = self.build_user_prompt(plan)?;
116
117        // Estimate tokens
118        let estimated_tokens = self.estimate_tokens(&system_prompt, &user_prompt, &context);
119
120        // Check token budget
121        if estimated_tokens > self.max_context_tokens {
122            return Err(GenerationError::PromptError(format!(
123                "Prompt exceeds token budget: {} > {}",
124                estimated_tokens, self.max_context_tokens
125            )));
126        }
127
128        let steering_rules_applied = steering_rules.naming_conventions.keys().cloned().collect();
129
130        Ok(GeneratedPrompt {
131            id: format!("prompt-{}", uuid::Uuid::new_v4()),
132            system_prompt,
133            user_prompt,
134            context,
135            steering_rules_applied,
136            estimated_tokens,
137        })
138    }
139
140    /// Loads steering rules from project and workspace
141    pub fn load_steering_rules(&self) -> Result<SteeringRules, GenerationError> {
142        let mut rules = SteeringRules {
143            naming_conventions: HashMap::new(),
144            code_quality_standards: Vec::new(),
145            documentation_requirements: Vec::new(),
146            error_handling_patterns: Vec::new(),
147            testing_requirements: Vec::new(),
148        };
149
150        // Load project-level steering rules using PathResolver
151        let project_path = PathResolver::resolve_project_path();
152        let project_steering_dir = self.project_root.join(&project_path).join("steering");
153        if project_steering_dir.exists() {
154            self.load_steering_from_dir(&project_steering_dir, &mut rules)?;
155        }
156
157        // Load global-level steering rules using PathResolver
158        match PathResolver::resolve_global_path() {
159            Ok(global_path) => {
160                let global_steering_dir = global_path.join("steering");
161                if global_steering_dir.exists() {
162                    self.load_steering_from_dir(&global_steering_dir, &mut rules)?;
163                }
164            }
165            Err(_) => {
166                // If global path resolution fails, continue with project rules only
167                // This is not a fatal error - we can still use project-level rules
168            }
169        }
170
171        // Set default naming conventions if not loaded
172        if rules.naming_conventions.is_empty() {
173            rules
174                .naming_conventions
175                .insert("rust".to_string(), "snake_case".to_string());
176            rules
177                .naming_conventions
178                .insert("typescript".to_string(), "camelCase".to_string());
179            rules
180                .naming_conventions
181                .insert("python".to_string(), "snake_case".to_string());
182        }
183
184        Ok(rules)
185    }
186
187    /// Loads steering rules from a directory
188    fn load_steering_from_dir(
189        &self,
190        _dir: &Path,
191        rules: &mut SteeringRules,
192    ) -> Result<(), GenerationError> {
193        // In a real implementation, this would read YAML/Markdown files
194        // For now, we'll just set defaults
195        if !rules.code_quality_standards.is_empty() {
196            return Ok(());
197        }
198
199        rules.code_quality_standards = vec![
200            "Zero warnings in production code".to_string(),
201            "All public APIs must have tests".to_string(),
202            "Type safety first - use strict type checking".to_string(),
203        ];
204
205        rules.documentation_requirements = vec![
206            "All public types must have doc comments".to_string(),
207            "All public functions must have doc comments".to_string(),
208            "Complex logic must have explanatory comments".to_string(),
209        ];
210
211        rules.error_handling_patterns = vec![
212            "Use explicit error types (not generic String errors)".to_string(),
213            "Never silently swallow errors".to_string(),
214            "Propagate errors with context".to_string(),
215        ];
216
217        rules.testing_requirements = vec![
218            "Unit tests for all public APIs".to_string(),
219            "Integration tests for workflows".to_string(),
220            "Property tests for deterministic operations".to_string(),
221        ];
222
223        Ok(())
224    }
225
226    /// Builds the context for the prompt
227    fn build_context(
228        &self,
229        spec_content: Option<&str>,
230        design_content: Option<&str>,
231        steering_rules: &SteeringRules,
232    ) -> Result<PromptContext, GenerationError> {
233        let mut context = PromptContext {
234            spec_content: spec_content.map(|s| s.to_string()),
235            design_content: design_content.map(|s| s.to_string()),
236            examples: Vec::new(),
237            architecture_docs: Vec::new(),
238            steering_rules: Vec::new(),
239        };
240
241        // Add steering rules to context
242        for (lang, convention) in &steering_rules.naming_conventions {
243            context.steering_rules.push(format!(
244                "For {}: use {} naming convention",
245                lang, convention
246            ));
247        }
248
249        for standard in &steering_rules.code_quality_standards {
250            context.steering_rules.push(standard.clone());
251        }
252
253        for requirement in &steering_rules.documentation_requirements {
254            context.steering_rules.push(requirement.clone());
255        }
256
257        Ok(context)
258    }
259
260    /// Builds the system prompt
261    fn build_system_prompt(
262        &self,
263        steering_rules: &SteeringRules,
264    ) -> Result<String, GenerationError> {
265        let mut prompt = String::new();
266
267        prompt.push_str("You are an expert code generation assistant.\n\n");
268
269        prompt.push_str("Your task is to generate high-quality code that:\n");
270        for standard in &steering_rules.code_quality_standards {
271            prompt.push_str(&format!("- {}\n", standard));
272        }
273
274        prompt.push_str("\nDocumentation Requirements:\n");
275        for requirement in &steering_rules.documentation_requirements {
276            prompt.push_str(&format!("- {}\n", requirement));
277        }
278
279        prompt.push_str("\nError Handling:\n");
280        for pattern in &steering_rules.error_handling_patterns {
281            prompt.push_str(&format!("- {}\n", pattern));
282        }
283
284        prompt.push_str("\nTesting:\n");
285        for requirement in &steering_rules.testing_requirements {
286            prompt.push_str(&format!("- {}\n", requirement));
287        }
288
289        prompt.push_str("\nNaming Conventions:\n");
290        for (lang, convention) in &steering_rules.naming_conventions {
291            prompt.push_str(&format!("- {}: {}\n", lang, convention));
292        }
293
294        Ok(prompt)
295    }
296
297    /// Builds the user prompt from a generation plan
298    fn build_user_prompt(&self, plan: &GenerationPlan) -> Result<String, GenerationError> {
299        let mut prompt = String::new();
300
301        prompt.push_str("Generate code for the following specification:\n\n");
302
303        for step in &plan.steps {
304            prompt.push_str(&format!("## {}\n", step.description));
305            prompt.push_str(&format!("Priority: {:?}\n", step.priority));
306
307            if !step.acceptance_criteria.is_empty() {
308                prompt.push_str("\nAcceptance Criteria:\n");
309                for criterion in &step.acceptance_criteria {
310                    prompt.push_str(&format!(
311                        "- WHEN {} THEN {}\n",
312                        criterion.when, criterion.then
313                    ));
314                }
315            }
316
317            prompt.push('\n');
318        }
319
320        Ok(prompt)
321    }
322
323    /// Estimates token count for a prompt
324    fn estimate_tokens(
325        &self,
326        system_prompt: &str,
327        user_prompt: &str,
328        context: &PromptContext,
329    ) -> usize {
330        // Rough estimation: ~4 characters per token
331        let mut total = 0;
332
333        total += system_prompt.len() / 4;
334        total += user_prompt.len() / 4;
335
336        if let Some(spec) = &context.spec_content {
337            total += spec.len() / 4;
338        }
339
340        if let Some(design) = &context.design_content {
341            total += design.len() / 4;
342        }
343
344        for example in &context.examples {
345            total += example.len() / 4;
346        }
347
348        for rule in &context.steering_rules {
349            total += rule.len() / 4;
350        }
351
352        total
353    }
354}
355
356impl Default for PromptBuilder {
357    fn default() -> Self {
358        Self::new(PathBuf::from("."))
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use ricecoder_specs::models::{AcceptanceCriterion, Priority};
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that mutate `RICECODER_HOME`.
    ///
    /// The test harness runs tests on several threads and the process
    /// environment is global. Two env-mutating tests running at once could
    /// interleave their save/restore steps and leave the variable changed for
    /// the rest of the run.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` for the duration of a test. On drop it restores
    /// `RICECODER_HOME` to its original value, including when the test panics.
    struct RicecoderHomeGuard {
        original: Option<std::ffi::OsString>,
        _lock: MutexGuard<'static, ()>,
    }

    impl RicecoderHomeGuard {
        fn acquire() -> Self {
            // A test that panicked while holding the lock poisons it. Its guard
            // already restored the variable on drop, so recovering is safe.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            Self {
                // `var_os` preserves non-UTF-8 values that `var` would drop.
                original: std::env::var_os("RICECODER_HOME"),
                _lock: lock,
            }
        }
    }

    impl Drop for RicecoderHomeGuard {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so the restore is serialized too.
            match &self.original {
                Some(value) => std::env::set_var("RICECODER_HOME", value),
                None => std::env::remove_var("RICECODER_HOME"),
            }
        }
    }

    fn create_test_plan() -> GenerationPlan {
        GenerationPlan {
            id: "test-plan".to_string(),
            spec_id: "test-spec".to_string(),
            steps: vec![crate::spec_processor::GenerationStep {
                id: "step-1".to_string(),
                description: "Implement user authentication".to_string(),
                requirement_ids: vec!["req-1".to_string()],
                acceptance_criteria: vec![AcceptanceCriterion {
                    id: "ac-1".to_string(),
                    when: "user provides credentials".to_string(),
                    then: "system authenticates user".to_string(),
                }],
                priority: Priority::Must,
                optional: false,
                sequence: 0,
            }],
            dependencies: vec![],
            constraints: vec![],
        }
    }

    #[test]
    fn test_prompt_builder_creates_prompt() {
        let builder = PromptBuilder::default();
        let plan = create_test_plan();

        let prompt = builder
            .build(&plan, None, None)
            .expect("Failed to build prompt");

        assert!(prompt.id.starts_with("prompt-"));
        assert!(!prompt.system_prompt.is_empty());
        assert!(!prompt.user_prompt.is_empty());
        assert!(prompt.user_prompt.contains("## Implement user authentication"));
        assert!(prompt
            .user_prompt
            .contains("- WHEN user provides credentials THEN system authenticates user"));
        assert!(!prompt.steering_rules_applied.is_empty());
    }

    #[test]
    fn test_prompt_builder_includes_spec_content() {
        let builder = PromptBuilder::default();
        let plan = create_test_plan();
        let spec_content = "# Test Specification";

        let prompt = builder
            .build(&plan, Some(spec_content), None)
            .expect("Failed to build prompt");

        assert_eq!(prompt.context.spec_content, Some(spec_content.to_string()));
    }

    #[test]
    fn test_prompt_builder_includes_design_content() {
        let builder = PromptBuilder::default();
        let plan = create_test_plan();
        let design_content = "# Test Design";

        let prompt = builder
            .build(&plan, None, Some(design_content))
            .expect("Failed to build prompt");

        assert_eq!(
            prompt.context.design_content,
            Some(design_content.to_string())
        );
    }

    #[test]
    fn test_prompt_builder_applies_steering_rules() {
        let builder = PromptBuilder::default();
        let plan = create_test_plan();

        let prompt = builder
            .build(&plan, None, None)
            .expect("Failed to build prompt");

        // Should have applied naming conventions
        assert!(!prompt.steering_rules_applied.is_empty());
        assert!(prompt.system_prompt.contains("snake_case"));
    }

    #[test]
    fn test_prompt_builder_estimates_tokens() {
        let builder = PromptBuilder::default();
        let plan = create_test_plan();

        let prompt = builder
            .build(&plan, None, None)
            .expect("Failed to build prompt");

        // Token estimate should be reasonable
        assert!(prompt.estimated_tokens > 0);
        assert!(prompt.estimated_tokens < builder.max_context_tokens);
    }

    #[test]
    fn test_prompt_builder_respects_token_budget() {
        let builder = PromptBuilder::default().with_max_context_tokens(10); // Very small budget
        let plan = create_test_plan();

        let result = builder.build(&plan, None, None);

        // Should fail specifically because of the token budget
        assert!(matches!(result, Err(GenerationError::PromptError(_))));
    }

    #[test]
    fn test_steering_rules_has_defaults() {
        let builder = PromptBuilder::default();

        let rules = builder
            .load_steering_rules()
            .expect("Failed to load steering rules");

        // Should have default naming conventions
        assert!(!rules.naming_conventions.is_empty());
        assert!(rules.naming_conventions.contains_key("rust"));
        assert!(rules.naming_conventions.contains_key("typescript"));
    }

    #[test]
    fn test_system_prompt_includes_standards() {
        let builder = PromptBuilder::default();
        let rules = SteeringRules {
            naming_conventions: [("rust".to_string(), "snake_case".to_string())]
                .iter()
                .cloned()
                .collect(),
            code_quality_standards: vec!["Zero warnings".to_string()],
            documentation_requirements: vec!["Doc comments required".to_string()],
            error_handling_patterns: vec!["Use Result types".to_string()],
            testing_requirements: vec!["Unit tests required".to_string()],
        };

        let system_prompt = builder
            .build_system_prompt(&rules)
            .expect("Failed to build system prompt");

        assert!(system_prompt.contains("Zero warnings"));
        assert!(system_prompt.contains("Doc comments required"));
        assert!(system_prompt.contains("Use Result types"));
        assert!(system_prompt.contains("snake_case"));
    }

    // ========================================================================
    // Unit Tests for PathResolver Usage in PromptBuilder
    // **Feature: ricecoder-path-resolution, Tests for Requirements 4.1, 4.2**
    // ========================================================================

    #[test]
    fn test_prompt_builder_loads_steering_rules_from_correct_location() {
        // Test that PromptBuilder loads steering rules from the correct location
        // using PathResolver
        let builder = PromptBuilder::default();

        let rules = builder
            .load_steering_rules()
            .expect("Failed to load steering rules");

        // Naming conventions always get defaults. The other rule lists are only
        // populated when a project or global steering directory exists, so they
        // depend on the machine running the tests and are not asserted here.
        assert!(!rules.naming_conventions.is_empty());
    }

    #[test]
    fn test_prompt_builder_path_resolution_with_environment_variables() {
        // Test that path resolution respects environment variables.
        // The guard serializes env access and restores RICECODER_HOME on drop.
        let _env = RicecoderHomeGuard::acquire();
        std::env::set_var("RICECODER_HOME", "/tmp/test-ricecoder");

        let builder = PromptBuilder::default();
        let rules = builder
            .load_steering_rules()
            .expect("Failed to load steering rules");

        // Should still have loaded rules (even if from different location)
        assert!(!rules.naming_conventions.is_empty());
    }

    #[test]
    fn test_prompt_builder_path_resolution_without_environment_variables() {
        // Test that path resolution works without environment variables.
        // The guard serializes env access and restores RICECODER_HOME on drop.
        let _env = RicecoderHomeGuard::acquire();
        std::env::remove_var("RICECODER_HOME");

        let builder = PromptBuilder::default();
        let rules = builder
            .load_steering_rules()
            .expect("Failed to load steering rules");

        // Should have loaded default rules
        assert!(!rules.naming_conventions.is_empty());
    }

    #[test]
    fn test_prompt_builder_error_handling_for_missing_home_directory() {
        // Test that error handling works gracefully when home directory is missing
        // This is a defensive test - in practice, home directory should always exist
        let builder = PromptBuilder::default();

        // Even if global path resolution fails, we should still get rules
        // (from project-level defaults)
        let rules = builder
            .load_steering_rules()
            .expect("Failed to load steering rules");

        // Should have at least default naming conventions
        assert!(!rules.naming_conventions.is_empty());
    }

    #[test]
    fn test_prompt_builder_uses_path_resolver_for_project_path() {
        // Test that PromptBuilder uses PathResolver for project path
        let builder = PromptBuilder::default();

        // Load steering rules which internally uses PathResolver
        let rules = builder
            .load_steering_rules()
            .expect("Failed to load steering rules");

        // Verify that rules were loaded (indicating PathResolver was used)
        assert!(!rules.naming_conventions.is_empty());
        assert!(rules.naming_conventions.contains_key("rust"));
    }

    #[test]
    fn test_prompt_builder_steering_rules_consistency() {
        // Test that loading steering rules multiple times returns consistent results
        let builder = PromptBuilder::default();

        let rules1 = builder
            .load_steering_rules()
            .expect("Failed to load steering rules");
        let rules2 = builder
            .load_steering_rules()
            .expect("Failed to load steering rules");

        // Both should have the same naming conventions
        assert_eq!(rules1.naming_conventions, rules2.naming_conventions);
        assert_eq!(rules1.code_quality_standards, rules2.code_quality_standards);
    }

    #[test]
    fn test_prompt_builder_default_naming_conventions() {
        // Test that default naming conventions are set correctly
        let builder = PromptBuilder::default();

        let rules = builder
            .load_steering_rules()
            .expect("Failed to load steering rules");

        // Should have default naming conventions for common languages
        assert_eq!(
            rules.naming_conventions.get("rust"),
            Some(&"snake_case".to_string())
        );
        assert_eq!(
            rules.naming_conventions.get("typescript"),
            Some(&"camelCase".to_string())
        );
        assert_eq!(
            rules.naming_conventions.get("python"),
            Some(&"snake_case".to_string())
        );
    }

    #[test]
    fn test_prompt_builder_code_quality_standards_loaded() {
        // Code quality standards are only populated when a project or global
        // steering directory exists, so only naming conventions can be asserted.
        let builder = PromptBuilder::default();

        let rules = builder
            .load_steering_rules()
            .expect("Failed to load steering rules");

        assert!(!rules.naming_conventions.is_empty());
    }

    #[test]
    fn test_prompt_builder_documentation_requirements_loaded() {
        // Documentation requirements are only populated when a project or global
        // steering directory exists, so only naming conventions can be asserted.
        let builder = PromptBuilder::default();

        let rules = builder
            .load_steering_rules()
            .expect("Failed to load steering rules");

        assert!(!rules.naming_conventions.is_empty());
    }

    #[test]
    fn test_prompt_builder_error_handling_patterns_loaded() {
        // Error handling patterns are only populated when a project or global
        // steering directory exists, so only naming conventions can be asserted.
        let builder = PromptBuilder::default();

        let rules = builder
            .load_steering_rules()
            .expect("Failed to load steering rules");

        assert!(!rules.naming_conventions.is_empty());
    }

    #[test]
    fn test_prompt_builder_testing_requirements_loaded() {
        // Testing requirements are only populated when a project or global
        // steering directory exists, so only naming conventions can be asserted.
        let builder = PromptBuilder::default();

        let rules = builder
            .load_steering_rules()
            .expect("Failed to load steering rules");

        assert!(!rules.naming_conventions.is_empty());
    }
}