1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::hash::{Hash, Hasher};
8
9#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
14pub struct InjectionContext {
15 pub project: Option<String>,
17
18 pub language: Option<String>,
20
21 pub framework: Option<String>,
23
24 pub architecture: Option<String>,
26
27 pub style: Option<StyleGuide>,
29
30 pub surrounding_code: Option<String>,
32
33 pub available_imports: Vec<String>,
35
36 pub variables: HashMap<String, String>,
38
39 pub extra: HashMap<String, serde_json::Value>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
45pub struct StyleGuide {
46 pub indent: IndentStyle,
48
49 pub max_line_length: Option<usize>,
51
52 pub semicolons: Option<bool>,
54
55 pub quote_style: Option<QuoteStyle>,
57
58 pub naming_convention: Option<NamingConvention>,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
64pub enum IndentStyle {
65 Spaces(u8),
67 Tabs,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
73#[serde(rename_all = "snake_case")]
74pub enum QuoteStyle {
75 Single,
76 Double,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
81#[serde(rename_all = "snake_case")]
82pub enum NamingConvention {
83 CamelCase,
84 PascalCase,
85 SnakeCase,
86 KebabCase,
87}
88
89impl Hash for InjectionContext {
90 fn hash<H: Hasher>(&self, state: &mut H) {
91 self.project.hash(state);
92 self.language.hash(state);
93 self.framework.hash(state);
94 self.architecture.hash(state);
95 self.style.hash(state);
96 self.surrounding_code.hash(state);
97 self.available_imports.hash(state);
98
99 let mut vars: Vec<_> = self.variables.iter().collect();
100 vars.sort_by_key(|(k, _)| *k);
101 for (k, v) in vars {
102 k.hash(state);
103 v.hash(state);
104 }
105
106 let mut extra_sorted: Vec<_> = self.extra.iter().collect();
107 extra_sorted.sort_by_key(|(k, _)| *k);
108 for (k, v) in extra_sorted {
109 k.hash(state);
110 serde_json::to_string(v).unwrap_or_default().hash(state);
111 }
112 }
113}
114
115impl InjectionContext {
116 pub fn new() -> Self {
118 Self::default()
119 }
120
121 pub fn with_project(mut self, project: impl Into<String>) -> Self {
123 self.project = Some(project.into());
124 self
125 }
126
127 pub fn with_language(mut self, language: impl Into<String>) -> Self {
129 self.language = Some(language.into());
130 self
131 }
132
133 pub fn with_framework(mut self, framework: impl Into<String>) -> Self {
135 self.framework = Some(framework.into());
136 self
137 }
138
139 pub fn with_architecture(mut self, architecture: impl Into<String>) -> Self {
141 self.architecture = Some(architecture.into());
142 self
143 }
144
145 pub fn with_style(mut self, style: StyleGuide) -> Self {
147 self.style = Some(style);
148 self
149 }
150
151 pub fn with_surrounding_code(mut self, code: impl Into<String>) -> Self {
153 self.surrounding_code = Some(code.into());
154 self
155 }
156
157 pub fn add_import(mut self, import: impl Into<String>) -> Self {
159 self.available_imports.push(import.into());
160 self
161 }
162
163 pub fn set_variable(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
165 self.variables.insert(key.into(), value.into());
166 self
167 }
168
169 pub fn to_prompt(&self) -> String {
171 let mut parts = Vec::new();
172
173 if let Some(ref project) = self.project {
174 parts.push(format!("Project: {}", project));
175 }
176
177 if let Some(ref lang) = self.language {
178 parts.push(format!("Language: {}", lang));
179 }
180
181 if let Some(ref fw) = self.framework {
182 parts.push(format!("Framework: {}", fw));
183 }
184
185 if let Some(ref arch) = self.architecture {
186 parts.push(format!("Architecture: {}", arch));
187 }
188
189 if let Some(ref style) = self.style {
190 let mut style_parts = Vec::new();
191 match &style.indent {
192 IndentStyle::Spaces(n) => style_parts.push(format!("{} spaces indent", n)),
193 IndentStyle::Tabs => style_parts.push("tabs indent".to_string()),
194 }
195 if let Some(max) = style.max_line_length {
196 style_parts.push(format!("max {} chars per line", max));
197 }
198 if !style_parts.is_empty() {
199 parts.push(format!("Style: {}", style_parts.join(", ")));
200 }
201 }
202
203 if !self.available_imports.is_empty() {
204 parts.push(format!("Available imports: {}", self.available_imports.join(", ")));
205 }
206
207 if let Some(ref code) = self.surrounding_code {
208 parts.push(format!("Surrounding code:\n```\n{}\n```", code));
209 }
210
211 parts.join("\n")
212 }
213}
214
215impl Default for StyleGuide {
216 fn default() -> Self {
217 Self {
218 indent: IndentStyle::Spaces(4),
219 max_line_length: Some(100),
220 semicolons: None,
221 quote_style: None,
222 naming_convention: None,
223 }
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Hashes `ctx` with a freshly constructed `DefaultHasher`.
    fn hash_of(ctx: &InjectionContext) -> u64 {
        let mut hasher = DefaultHasher::new();
        ctx.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_context_builder() {
        let ctx = InjectionContext::new()
            .with_project("my-app")
            .with_language("typescript")
            .with_framework("react");

        assert_eq!(ctx.project, Some("my-app".to_string()));
        assert_eq!(ctx.language, Some("typescript".to_string()));
        assert_eq!(ctx.framework, Some("react".to_string()));
    }

    #[test]
    fn test_context_to_prompt() {
        let ctx = InjectionContext::new()
            .with_project("test")
            .with_language("rust");

        let prompt = ctx.to_prompt();
        assert!(prompt.contains("Project: test"));
        assert!(prompt.contains("Language: rust"));
    }

    #[test]
    fn test_prompt_lines_follow_field_order() {
        // Builder call order must not affect the rendered line order.
        let ctx = InjectionContext::new()
            .with_language("rust")
            .with_project("test");

        assert_eq!(ctx.to_prompt(), "Project: test\nLanguage: rust");
    }

    #[test]
    fn test_empty_context_renders_empty_prompt() {
        assert_eq!(InjectionContext::new().to_prompt(), "");
    }

    #[test]
    fn test_default_style_in_prompt() {
        let prompt = InjectionContext::new()
            .with_style(StyleGuide::default())
            .to_prompt();

        assert_eq!(prompt, "Style: 4 spaces indent, max 100 chars per line");
    }

    #[test]
    fn test_tabs_style_without_line_limit() {
        let style = StyleGuide {
            indent: IndentStyle::Tabs,
            max_line_length: None,
            ..StyleGuide::default()
        };
        let prompt = InjectionContext::new().with_style(style).to_prompt();

        assert_eq!(prompt, "Style: tabs indent");
    }

    #[test]
    fn test_imports_and_surrounding_code() {
        let prompt = InjectionContext::new()
            .add_import("std::fmt")
            .add_import("std::io")
            .with_surrounding_code("fn x() {}")
            .to_prompt();

        assert_eq!(
            prompt,
            "Available imports: std::fmt, std::io\nSurrounding code:\n```\nfn x() {}\n```"
        );
    }

    #[test]
    fn test_hash_ignores_variable_insertion_order() {
        let a = InjectionContext::new()
            .set_variable("a", "1")
            .set_variable("b", "2");
        let b = InjectionContext::new()
            .set_variable("b", "2")
            .set_variable("a", "1");

        assert_eq!(a, b);
        assert_eq!(hash_of(&a), hash_of(&b));
    }
}