use anyhow::{Context, Result};
use serde::Serialize;
use std::collections::HashMap;
use std::path::PathBuf;
use crate::configuration::Code2PromptConfig;
use crate::git::{get_git_diff, get_git_diff_between_branches, get_git_log};
use crate::path::{FileEntry, display_name, traverse_directory, wrap_code_block};
use crate::selection::SelectionEngine;
use crate::template::{OutputFormat, handlebars_setup, render_template};
use crate::tokenizer::{TokenizerType, count_tokens};
/// A prompt-generation session: the configuration, the file-selection engine
/// built from that configuration's include/exclude patterns, and the data
/// loaded so far (source tree, files, git output).
#[derive(Debug, Clone)]
pub struct Code2PromptSession {
/// Configuration the session was created with; the pattern lists may be
/// extended through `add_include_pattern` / `add_exclude_pattern`.
pub config: Code2PromptConfig,
/// Engine answering "is this file selected?"; rebuilt whenever a pattern is added.
pub selection_engine: SelectionEngine,
/// Results of the `load_*` methods; starts out as `SessionData::default()`.
pub data: SessionData,
}
/// Data accumulated by a session's `load_*` methods.
///
/// Every field is `None` until the corresponding loader has succeeded.
#[derive(Debug, Default, Clone)]
pub struct SessionData {
/// Display name of the configured path, set by `load_codebase`.
pub absolute_code_path: Option<String>,
/// Textual source tree returned by `traverse_directory`, set by `load_codebase`.
pub source_tree: Option<String>,
/// File entries returned by `traverse_directory`, set by `load_codebase`.
pub files: Option<Vec<FileEntry>>,
/// NOTE(review): not written by any code visible in this file — confirm who populates it.
pub stats: Option<serde_json::Value>,
/// Output of `get_git_diff`, set by `load_git_diff`.
pub git_diff: Option<String>,
/// Output of `get_git_diff_between_branches`, set by `load_git_diff_between_branches`.
pub git_diff_branch: Option<String>,
/// Output of `get_git_log`, set by `load_git_log_between_branches`.
pub git_log_branch: Option<String>,
}
/// Borrowed, serializable view of the session data handed to the template renderer.
///
/// Optional fields that are `None` are omitted from the serialized output, and
/// the user variables are flattened into the same level as the named fields.
#[derive(Serialize)]
pub struct TemplateContext<'a> {
/// Display path of the code base (`"unknown"` when built before `load_codebase`).
pub absolute_code_path: &'a str,
/// Source tree text, if loaded.
#[serde(skip_serializing_if = "Option::is_none")]
pub source_tree: &'a Option<String>,
/// Loaded file entries, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub files: Option<&'a [FileEntry]>,
/// Git diff text, if loaded.
#[serde(skip_serializing_if = "Option::is_none")]
pub git_diff: &'a Option<String>,
/// Git diff between the two configured branches, if loaded.
#[serde(skip_serializing_if = "Option::is_none")]
pub git_diff_branch: &'a Option<String>,
/// Git log between the two configured branches, if loaded.
#[serde(skip_serializing_if = "Option::is_none")]
pub git_log_branch: &'a Option<String>,
/// User-defined variables, serialized as top-level keys next to the fields above.
#[serde(flatten)]
pub user_variables: &'a HashMap<String, String>,
}
/// Result of `Code2PromptSession::render_prompt`.
#[derive(Debug)]
pub struct RenderedPrompt {
/// Final output text: the rendered template, or for the JSON output format a
/// pretty-printed JSON document that wraps the rendered template.
pub prompt: String,
/// The `absolute_code_path` of the template context that was rendered.
pub directory_name: String,
/// Sum of the per-file token counts plus the tokens of the surrounding template structure.
pub token_count: usize,
/// Description string of the configured tokenizer.
pub model_info: &'static str,
/// Paths of the loaded file entries.
pub files: Vec<String>,
}
impl Code2PromptSession {
/// Creates a session for `config`.
///
/// The selection engine is seeded with copies of the configuration's include
/// and exclude patterns; no data is loaded yet.
pub fn new(config: Code2PromptConfig) -> Self {
    let includes = config.include_patterns.clone();
    let excludes = config.exclude_patterns.clone();

    Self {
        selection_engine: SelectionEngine::new(includes, excludes),
        config,
        data: SessionData::default(),
    }
}
/// Appends `pattern` to the include patterns and rebuilds the selection engine
/// from the updated pattern lists.
///
/// NOTE(review): the engine is replaced wholesale, so any per-file state held by
/// the previous engine is presumably dropped — confirm against `SelectionEngine`.
///
/// Returns `self` to allow chaining.
pub fn add_include_pattern(&mut self, pattern: String) -> &mut Self {
    self.config.include_patterns.push(pattern);

    let includes = self.config.include_patterns.clone();
    let excludes = self.config.exclude_patterns.clone();
    self.selection_engine = SelectionEngine::new(includes, excludes);

    self
}
/// Appends `pattern` to the exclude patterns and rebuilds the selection engine
/// from the updated pattern lists.
///
/// NOTE(review): the engine is replaced wholesale, so any per-file state held by
/// the previous engine is presumably dropped — confirm against `SelectionEngine`.
///
/// Returns `self` to allow chaining.
pub fn add_exclude_pattern(&mut self, pattern: String) -> &mut Self {
    self.config.exclude_patterns.push(pattern);

    let rebuilt = SelectionEngine::new(
        self.config.include_patterns.clone(),
        self.config.exclude_patterns.clone(),
    );
    self.selection_engine = rebuilt;

    self
}
/// Marks `path` as included in the selection engine.
///
/// An absolute path is first made relative to the configured root; if it does
/// not lie under the root it is passed through unchanged. Relative paths are
/// used as given. Returns `self` to allow chaining.
pub fn select_file(&mut self, path: PathBuf) -> &mut Self {
    let target = if path.is_relative() {
        path
    } else {
        // Fall back to the original path when it is not under the root.
        path.strip_prefix(&self.config.path)
            .map(|inside_root| inside_root.to_path_buf())
            .unwrap_or(path)
    };

    self.selection_engine.include_file(target);
    self
}
/// Marks `path` as excluded in the selection engine.
///
/// An absolute path is first made relative to the configured root; if it does
/// not lie under the root it is passed through unchanged. Relative paths are
/// used as given. Returns `self` to allow chaining.
pub fn deselect_file(&mut self, path: PathBuf) -> &mut Self {
    let mut target = path;
    if target.is_absolute() {
        // Only rewrite the path when it really sits below the root.
        if let Ok(inside_root) = target.strip_prefix(&self.config.path) {
            target = inside_root.to_path_buf();
        }
    }

    self.selection_engine.exclude_file(target);
    self
}
/// Toggles the selection state of `path` in the selection engine.
///
/// An absolute path is first made relative to the configured root; if it does
/// not lie under the root it is passed through unchanged. Relative paths are
/// used as given. Returns `self` to allow chaining.
pub fn toggle_file_selection(&mut self, path: PathBuf) -> &mut Self {
    // `Some` only for absolute paths that lie under the configured root.
    let inside_root = if path.is_absolute() {
        path.strip_prefix(&self.config.path)
            .ok()
            .map(|rest| rest.to_path_buf())
    } else {
        None
    };
    let target = inside_root.unwrap_or(path);

    self.selection_engine.toggle_file(target);
    self
}
/// Asks the selection engine whether `path` is currently selected.
///
/// An absolute path is looked up relative to the configured root when it lies
/// under it; otherwise the path is looked up exactly as given.
pub fn is_file_selected(&mut self, path: &std::path::Path) -> bool {
    let mut lookup = path;
    if path.is_absolute() {
        if let Ok(inside_root) = path.strip_prefix(&self.config.path) {
            lookup = inside_root;
        }
    }

    self.selection_engine.is_selected(lookup)
}
/// Returns the files the selection engine reports as selected under the
/// configured root path.
///
/// # Errors
/// Propagates any error returned by the selection engine.
pub fn get_selected_files(&mut self) -> Result<Vec<PathBuf>> {
    let root = &self.config.path;
    let selected = self.selection_engine.get_selected_files(root)?;
    Ok(selected)
}
/// Forwards to the selection engine's `clear_user_actions` and returns `self`
/// to allow chaining.
pub fn clear_user_actions(&mut self) -> &mut Self {
    let engine = &mut self.selection_engine;
    engine.clear_user_actions();
    self
}
/// Reports whether the selection engine currently holds any user actions.
pub fn has_user_actions(&self) -> bool {
    let engine = &self.selection_engine;
    engine.has_user_actions()
}
/// Traverses the configured directory and stores the resulting source tree,
/// file entries and display path in `self.data`.
///
/// # Errors
/// Returns the traversal error wrapped with the context
/// "Failed to traverse directory"; `self.data` is left untouched in that case.
pub fn load_codebase(&mut self) -> Result<()> {
    let traversal = traverse_directory(&self.config, Some(&mut self.selection_engine));
    let (tree, files) = traversal.context("Failed to traverse directory")?;

    let shown_path = display_name(&self.config.path);
    self.data.files = Some(files);
    self.data.source_tree = Some(tree);
    self.data.absolute_code_path = Some(shown_path);
    Ok(())
}
/// Runs `get_git_diff` on the configured path and stores the output in
/// `self.data.git_diff`.
///
/// # Errors
/// Propagates the error from `get_git_diff`; the stored diff is not modified then.
pub fn load_git_diff(&mut self) -> Result<()> {
    self.data.git_diff = Some(get_git_diff(&self.config.path)?);
    Ok(())
}
/// Stores the diff between the two configured branches in
/// `self.data.git_diff_branch`.
///
/// Does nothing (and succeeds) when no branch pair is configured.
///
/// # Errors
/// Propagates the error from `get_git_diff_between_branches`.
pub fn load_git_diff_between_branches(&mut self) -> Result<()> {
    if let Some((from_branch, to_branch)) = self.config.diff_branches.as_ref() {
        let root = &self.config.path;
        self.data.git_diff_branch =
            Some(get_git_diff_between_branches(root, from_branch, to_branch)?);
    }
    Ok(())
}
/// Stores the git log between the two configured branches in
/// `self.data.git_log_branch`.
///
/// Does nothing (and succeeds) when no branch pair is configured.
///
/// # Errors
/// Propagates the error from `get_git_log`.
pub fn load_git_log_between_branches(&mut self) -> Result<()> {
    if let Some((from_branch, to_branch)) = self.config.log_branches.as_ref() {
        let root = &self.config.path;
        self.data.git_log_branch = Some(get_git_log(root, from_branch, to_branch)?);
    }
    Ok(())
}
/// Builds a borrowed [`TemplateContext`] over the session's loaded data and the
/// configured user variables.
///
/// The code path falls back to `"unknown"` when no code base has been loaded.
pub fn build_template_data(&self) -> TemplateContext<'_> {
    let data = &self.data;
    let absolute_code_path = match &data.absolute_code_path {
        Some(code_path) => code_path.as_str(),
        None => "unknown",
    };

    TemplateContext {
        absolute_code_path,
        source_tree: &data.source_tree,
        files: data.files.as_deref(),
        git_diff: &data.git_diff,
        git_diff_branch: &data.git_diff_branch,
        git_log_branch: &data.git_log_branch,
        user_variables: &self.config.user_variables,
    }
}
/// Renders `template_context` with the configured template and packages the
/// result together with token statistics.
///
/// When the configured template string is empty, a built-in template is used:
/// the Markdown template for the Markdown format, the XML template for both
/// the XML and JSON formats. For the JSON format the rendered text is wrapped
/// in a pretty-printed JSON document that also carries the directory name,
/// token count, tokenizer description and file list.
///
/// The token count and file list come from the session's own loaded data, not
/// from `template_context`.
///
/// # Errors
/// Fails if the template cannot be registered or rendered, or if the JSON
/// document cannot be serialized.
pub fn render_prompt(&self, template_context: &TemplateContext) -> Result<RenderedPrompt> {
    // Pick template text and name together so the two can never drift apart.
    let (template_str, template_name) = if self.config.template_str.is_empty() {
        match self.config.output_format {
            OutputFormat::Markdown => (
                include_str!("./default_template_md.hbs").to_string(),
                "markdown".to_string(),
            ),
            OutputFormat::Xml | OutputFormat::Json => (
                include_str!("./default_template_xml.hbs").to_string(),
                "xml".to_string(),
            ),
        }
    } else {
        (
            self.config.template_str.clone(),
            self.config.template_name.clone(),
        )
    };

    let handlebars = handlebars_setup(&template_str, &template_name)?;
    let template_content = render_template(&handlebars, &template_name, template_context)?;

    let tokenizer_type: TokenizerType = self.config.encoding;
    let token_count = self.calculate_token_count_from_cache(&tokenizer_type);
    let model_info = tokenizer_type.description();
    let directory_name = template_context.absolute_code_path.to_string();
    let files: Vec<String> = self
        .data
        .files
        .iter()
        .flatten()
        .map(|file| file.path.clone())
        .collect();

    // Variants are listed explicitly so a new output format fails to compile
    // here instead of silently falling into a catch-all arm.
    let prompt = match self.config.output_format {
        OutputFormat::Json => {
            // `json!` only borrows interpolated expressions, so the values can
            // still be moved into the `RenderedPrompt` below without cloning.
            let json_data = serde_json::json!({
                "prompt": template_content,
                "directory_name": directory_name,
                "token_count": token_count,
                "model_info": model_info,
                "files": files,
            });
            serde_json::to_string_pretty(&json_data)?
        }
        OutputFormat::Markdown | OutputFormat::Xml => template_content,
    };

    Ok(RenderedPrompt {
        prompt,
        directory_name,
        token_count,
        model_info,
        files,
    })
}
/// Total token count: the sum of the `token_count` stored on each loaded file
/// entry plus the tokens of the template structure around them.
///
/// With no files loaded, only the structural tokens are counted.
fn calculate_token_count_from_cache(&self, tokenizer_type: &TokenizerType) -> usize {
    let cached_file_tokens = self
        .data
        .files
        .iter()
        .flatten()
        .map(|entry| entry.token_count)
        .sum::<usize>();

    cached_file_tokens + self.calculate_structural_tokens(tokenizer_type)
}
/// Counts the tokens of everything in the rendered prompt except the file
/// contents themselves.
///
/// The template is rendered against a "skeleton" context in which every file
/// keeps its path, extension and metadata but has its code replaced by an empty
/// code block; the rendered skeleton is then tokenized. If the template cannot
/// be set up or rendered, `fallback_structural_estimate` is used instead.
fn calculate_structural_tokens(&self, tokenizer_type: &TokenizerType) -> usize {
    let config = &self.config;
    let data = &self.data;

    // Same entry, but with an empty (still wrapped) code block and no tokens.
    let blank_entry = |file: &FileEntry| FileEntry {
        path: file.path.clone(),
        extension: file.extension.clone(),
        code: wrap_code_block(
            "",
            &file.extension,
            config.line_numbers,
            config.no_codeblock,
        ),
        token_count: 0,
        metadata: file.metadata,
        mod_time: file.mod_time,
    };
    let skeleton_files: Option<Vec<FileEntry>> = data
        .files
        .as_ref()
        .map(|files| files.iter().map(blank_entry).collect());

    let skeleton_context = TemplateContext {
        absolute_code_path: data.absolute_code_path.as_deref().unwrap_or("unknown"),
        source_tree: &data.source_tree,
        files: skeleton_files.as_deref(),
        git_diff: &data.git_diff,
        git_diff_branch: &data.git_diff_branch,
        git_log_branch: &data.git_log_branch,
        user_variables: &config.user_variables,
    };

    // Built-in template and name for the current output format; each is only
    // used when the corresponding configured value is empty.
    let (builtin_template, builtin_name) = match config.output_format {
        OutputFormat::Markdown => (include_str!("./default_template_md.hbs"), "markdown"),
        OutputFormat::Xml | OutputFormat::Json => {
            (include_str!("./default_template_xml.hbs"), "xml")
        }
    };
    let template_str = if config.template_str.is_empty() {
        builtin_template.to_string()
    } else {
        config.template_str.clone()
    };
    let template_name = if config.template_name.is_empty() {
        builtin_name.to_string()
    } else {
        config.template_name.clone()
    };

    // Any setup or render failure is deliberately swallowed: the count is
    // best-effort and falls back to an estimate.
    let skeleton_rendered = handlebars_setup(&template_str, &template_name)
        .ok()
        .and_then(|handlebars| {
            render_template(&handlebars, &template_name, &skeleton_context).ok()
        });

    match skeleton_rendered {
        Some(rendered) => count_tokens(&rendered, tokenizer_type),
        None => self.fallback_structural_estimate(tokenizer_type),
    }
}
/// Estimates the structural token count without rendering the template.
///
/// Considers the source tree and the three git outputs. When their combined
/// byte length is below 10000 they are concatenated and tokenized exactly;
/// otherwise the estimate is `total_len / 4 + 100`.
///
/// NOTE(review): the divisor 4 and the constant 100 look like a rough
/// bytes-per-token ratio plus template overhead — confirm the intent.
fn fallback_structural_estimate(&self, tokenizer_type: &TokenizerType) -> usize {
    let data = &self.data;
    // Missing sections count as empty text.
    let sections: [&str; 4] = [
        data.source_tree.as_deref().unwrap_or(""),
        data.git_diff.as_deref().unwrap_or(""),
        data.git_diff_branch.as_deref().unwrap_or(""),
        data.git_log_branch.as_deref().unwrap_or(""),
    ];
    let total_chars: usize = sections.iter().map(|section| section.len()).sum();

    if total_chars >= 10000 {
        return (total_chars / 4) + 100;
    }

    let combined = sections.concat();
    count_tokens(&combined, tokenizer_type)
}
/// Loads everything the configuration asks for and renders the prompt.
///
/// The code base is always loaded and a failure there aborts the call. The git
/// diff, branch diff and branch log are loaded only when enabled/configured,
/// and a failure in any of them is logged as a warning and otherwise ignored.
///
/// # Errors
/// Propagates errors from `load_codebase` and `render_prompt`.
pub fn generate_prompt(&mut self) -> Result<RenderedPrompt> {
    self.load_codebase()?;

    // Git data is best-effort: warn and carry on without it.
    if self.config.diff_enabled {
        if let Err(e) = self.load_git_diff() {
            log::warn!("Git diff could not be loaded: {}", e);
        }
    }
    if self.config.diff_branches.is_some() {
        if let Err(e) = self.load_git_diff_between_branches() {
            log::warn!("Git branch diff could not be loaded: {}", e);
        }
    }
    if self.config.log_branches.is_some() {
        if let Err(e) = self.load_git_log_between_branches() {
            log::warn!("Git branch log could not be loaded: {}", e);
        }
    }

    let template_data = self.build_template_data();
    self.render_prompt(&template_data)
}
}