use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct InjectionContext {
pub project: Option<String>,
pub language: Option<String>,
pub framework: Option<String>,
pub architecture: Option<String>,
pub style: Option<StyleGuide>,
pub surrounding_code: Option<String>,
pub available_imports: Vec<String>,
pub variables: HashMap<String, String>,
pub extra: HashMap<String, serde_json::Value>,
}
/// Code-style preferences to communicate alongside a prompt.
///
/// Only `indent` is mandatory; every other preference is optional and `None`
/// means "unspecified".
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub struct StyleGuide {
    /// Indentation unit.
    pub indent: IndentStyle,
    /// Maximum line length in characters, if limited.
    pub max_line_length: Option<usize>,
    /// Semicolon preference (`Some(true)` / `Some(false)`), if specified.
    pub semicolons: Option<bool>,
    /// Preferred string quote character, if specified.
    pub quote_style: Option<QuoteStyle>,
    /// Preferred identifier naming convention, if specified.
    pub naming_convention: Option<NamingConvention>,
}
/// Indentation unit for code.
///
/// Unlike `QuoteStyle` and `NamingConvention`, this enum carries no
/// `rename_all` attribute, so serde uses the Rust variant names on the wire
/// (externally tagged: `"Tabs"` and `{"Spaces": n}`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub enum IndentStyle {
    /// Indent with the given number of spaces.
    Spaces(u8),
    /// Indent with tab characters.
    Tabs,
}
/// Preferred quote character for string literals.
///
/// Serialized in snake_case: `"single"` / `"double"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QuoteStyle {
    /// Single quotes (`'...'`).
    Single,
    /// Double quotes (`"..."`).
    Double,
}
/// Preferred identifier naming convention.
///
/// Serialized in snake_case: `"camel_case"`, `"pascal_case"`, `"snake_case"`,
/// `"kebab_case"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NamingConvention {
    /// `camelCase`
    CamelCase,
    /// `PascalCase`
    PascalCase,
    /// `snake_case`
    SnakeCase,
    /// `kebab-case`
    KebabCase,
}
impl Hash for InjectionContext {
fn hash<H: Hasher>(&self, state: &mut H) {
self.project.hash(state);
self.language.hash(state);
self.framework.hash(state);
self.architecture.hash(state);
self.style.hash(state);
self.surrounding_code.hash(state);
self.available_imports.hash(state);
let mut vars: Vec<_> = self.variables.iter().collect();
vars.sort_by_key(|(k, _)| *k);
for (k, v) in vars {
k.hash(state);
v.hash(state);
}
let mut extra_sorted: Vec<_> = self.extra.iter().collect();
extra_sorted.sort_by_key(|(k, _)| *k);
for (k, v) in extra_sorted {
k.hash(state);
serde_json::to_string(v).unwrap_or_default().hash(state);
}
}
}
impl InjectionContext {
pub fn new() -> Self {
Self::default()
}
pub fn with_project(mut self, project: impl Into<String>) -> Self {
self.project = Some(project.into());
self
}
pub fn with_language(mut self, language: impl Into<String>) -> Self {
self.language = Some(language.into());
self
}
pub fn with_framework(mut self, framework: impl Into<String>) -> Self {
self.framework = Some(framework.into());
self
}
pub fn with_architecture(mut self, architecture: impl Into<String>) -> Self {
self.architecture = Some(architecture.into());
self
}
pub fn with_style(mut self, style: StyleGuide) -> Self {
self.style = Some(style);
self
}
pub fn with_surrounding_code(mut self, code: impl Into<String>) -> Self {
self.surrounding_code = Some(code.into());
self
}
pub fn add_import(mut self, import: impl Into<String>) -> Self {
self.available_imports.push(import.into());
self
}
pub fn set_variable(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.variables.insert(key.into(), value.into());
self
}
pub fn to_prompt(&self) -> String {
let mut parts = Vec::new();
if let Some(ref project) = self.project {
parts.push(format!("Project: {}", project));
}
if let Some(ref lang) = self.language {
parts.push(format!("Language: {}", lang));
}
if let Some(ref fw) = self.framework {
parts.push(format!("Framework: {}", fw));
}
if let Some(ref arch) = self.architecture {
parts.push(format!("Architecture: {}", arch));
}
if let Some(ref style) = self.style {
let mut style_parts = Vec::new();
match &style.indent {
IndentStyle::Spaces(n) => style_parts.push(format!("{} spaces indent", n)),
IndentStyle::Tabs => style_parts.push("tabs indent".to_string()),
}
if let Some(max) = style.max_line_length {
style_parts.push(format!("max {} chars per line", max));
}
if !style_parts.is_empty() {
parts.push(format!("Style: {}", style_parts.join(", ")));
}
}
if !self.available_imports.is_empty() {
parts.push(format!("Available imports: {}", self.available_imports.join(", ")));
}
if let Some(ref code) = self.surrounding_code {
parts.push(format!("Surrounding code:\n```\n{}\n```", code));
}
parts.join("\n")
}
}
impl Default for StyleGuide {
fn default() -> Self {
Self {
indent: IndentStyle::Spaces(4),
max_line_length: Some(100),
semicolons: None,
quote_style: None,
naming_convention: None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;

    /// Hashes `ctx` with a freshly constructed `DefaultHasher` so results are comparable.
    fn hash_of(ctx: &InjectionContext) -> u64 {
        let mut hasher = DefaultHasher::new();
        ctx.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_context_builder() {
        let ctx = InjectionContext::new()
            .with_project("my-app")
            .with_language("typescript")
            .with_framework("react");
        assert_eq!(ctx.project, Some("my-app".to_string()));
        assert_eq!(ctx.language, Some("typescript".to_string()));
        assert_eq!(ctx.framework, Some("react".to_string()));
    }

    #[test]
    fn test_context_to_prompt() {
        let ctx = InjectionContext::new()
            .with_project("test")
            .with_language("rust");
        let prompt = ctx.to_prompt();
        assert!(prompt.contains("Project: test"));
        assert!(prompt.contains("Language: rust"));
    }

    #[test]
    fn test_empty_context_renders_empty_prompt() {
        assert_eq!(InjectionContext::new().to_prompt(), "");
    }

    #[test]
    fn test_default_style_prompt() {
        let prompt = InjectionContext::new()
            .with_style(StyleGuide::default())
            .to_prompt();
        assert_eq!(prompt, "Style: 4 spaces indent, max 100 chars per line");
    }

    #[test]
    fn test_tabs_without_line_limit() {
        let style = StyleGuide {
            indent: IndentStyle::Tabs,
            max_line_length: None,
            semicolons: None,
            quote_style: None,
            naming_convention: None,
        };
        let prompt = InjectionContext::new().with_style(style).to_prompt();
        assert_eq!(prompt, "Style: tabs indent");
    }

    #[test]
    fn test_imports_precede_fenced_surrounding_code() {
        let prompt = InjectionContext::new()
            .with_surrounding_code("let x = 1;")
            .add_import("std::fmt")
            .add_import("std::io")
            .to_prompt();
        assert_eq!(
            prompt,
            "Available imports: std::fmt, std::io\nSurrounding code:\n```\nlet x = 1;\n```"
        );
    }

    #[test]
    fn test_equal_contexts_hash_equally_regardless_of_insertion_order() {
        let a = InjectionContext::new()
            .set_variable("x", "1")
            .set_variable("y", "2");
        let b = InjectionContext::new()
            .set_variable("y", "2")
            .set_variable("x", "1");
        assert_eq!(a, b);
        assert_eq!(hash_of(&a), hash_of(&b));
    }
}