use crate::error::{McpError, McpResult};
use crate::protocol::{Range, ToolCallResult};
use crate::tools::text_response;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::Mutex;
use windjammer::lexer::Lexer;
use windjammer::parser::{Expression, Parser, Statement, Type};
use windjammer_lsp::database::WindjammerDatabase;
/// Arguments accepted by the extract-function tool, deserialized from the
/// tool call's JSON `arguments` in [`handle`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ExtractFunctionRequest {
    /// Source text to lex and parse.
    pub code: String,
    /// Selection whose statements should be moved into the new function.
    pub range: Range,
    /// Name to give the extracted function (also used for the call site).
    pub function_name: String,
    /// Whether the extracted function should be public; omitted in the JSON
    /// means `false` (serde default). Forwarded to `generate_function`.
    #[serde(default)]
    pub make_public: bool,
}
/// Result of the extract-function tool, serialized to JSON as the text body
/// of the tool response.
///
/// On failure `success` is `false`, `error` holds a message and every other
/// optional field is `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExtractFunctionResponse {
    /// `true` when a function could be generated from the selection.
    pub success: bool,
    /// Proposed new function plus the call that should replace the selection.
    pub refactored_code: Option<String>,
    /// Human-readable signature of the extracted function.
    pub function_signature: Option<String>,
    /// Names of the variables the selection reads from its surroundings,
    /// i.e. the parameters of the extracted function.
    pub captured_variables: Option<Vec<String>>,
    /// Failure reason (parse error or empty selection).
    pub error: Option<String>,
}
pub async fn handle(
_db: Arc<Mutex<WindjammerDatabase>>,
arguments: Value,
) -> McpResult<ToolCallResult> {
let request: ExtractFunctionRequest =
serde_json::from_value(arguments).map_err(|e| McpError::ValidationError {
field: "arguments".to_string(),
message: e.to_string(),
})?;
let mut lexer = Lexer::new(&request.code);
let tokens = lexer.tokenize_with_locations();
let mut parser = Parser::new(tokens);
let parse_result = parser.parse();
let program = match parse_result {
Ok(prog) => prog,
Err(e) => {
let response = ExtractFunctionResponse {
success: false,
refactored_code: None,
function_signature: None,
captured_variables: None,
error: Some(format!("Parse error: {}", e)),
};
return Ok(text_response(&serde_json::to_string(&response)?));
}
};
let extracted = extract_statements_in_range(&program, &request.range);
if extracted.is_empty() {
let response = ExtractFunctionResponse {
success: false,
refactored_code: None,
function_signature: None,
captured_variables: None,
error: Some("No statements found in selection".to_string()),
};
return Ok(text_response(&serde_json::to_string(&response)?));
}
let analysis = analyze_variable_usage(&extracted);
let new_function = generate_function(
&request.function_name,
&extracted,
&analysis,
request.make_public,
);
let function_call = generate_function_call(&request.function_name, &analysis);
let refactored_code = format!(
"// TODO: Implement full AST manipulation\n// New function:\n{}\n\n// Replace selection with:\n{}",
new_function, function_call
);
let response = ExtractFunctionResponse {
success: true,
refactored_code: Some(refactored_code),
function_signature: Some(format!(
"fn {}({}) -> {}",
request.function_name,
analysis
.parameters
.iter()
.map(|(name, ty)| format!("{}: {:?}", name, ty))
.collect::<Vec<_>>()
.join(", "),
if let Some(ret) = &analysis.return_type {
format!("{:?}", ret)
} else {
"()".to_string()
}
)),
captured_variables: Some(analysis.parameters.iter().map(|(n, _)| n.clone()).collect()),
error: None,
};
Ok(text_response(&serde_json::to_string(&response)?))
}
/// Returns the statements of `_program` that lie inside `_range`.
///
/// Not implemented yet: the result is always empty, so `handle` currently
/// always answers with "No statements found in selection".
fn extract_statements_in_range(
    _program: &windjammer::parser::Program,
    _range: &Range,
) -> Vec<Statement> {
    // NOTE(review): presumably needs to walk the program's function bodies and
    // compare each statement's source location against `_range` — TODO.
    Vec::new()
}
/// Outcome of scanning a selection for variable reads and bindings.
#[derive(Debug)]
struct VariableAnalysis {
    /// Variables read but not bound by the selection, with their (currently
    /// placeholder) types; these become the extracted function's parameters.
    parameters: Vec<(String, Type)>,
    /// Inferred return type, `None` when the selection yields no value.
    return_type: Option<Type>,
    // Not read anywhere in this file; kept so the full analysis is visible
    // through the `Debug` derive (NOTE(review): confirm it is still wanted).
    #[allow(dead_code)]
    /// Every identifier the selection reads.
    used_variables: HashSet<String>,
    // Not read anywhere in this file; see `used_variables`.
    #[allow(dead_code)]
    /// Every plain identifier the selection binds (`let` / `for` patterns).
    defined_variables: HashSet<String>,
}
fn analyze_variable_usage(statements: &[Statement]) -> VariableAnalysis {
let mut used = HashSet::new();
let mut defined = HashSet::new();
for stmt in statements {
collect_variable_usage(stmt, &mut used, &mut defined);
}
let parameters: Vec<(String, Type)> = used
.difference(&defined)
.map(|name| (name.clone(), Type::String)) .collect();
let return_type = infer_return_type(statements);
VariableAnalysis {
parameters,
return_type,
used_variables: used,
defined_variables: defined,
}
}
/// Walks one statement, adding every identifier it reads to `used` and every
/// plain-identifier binding it introduces (`let` or `for` pattern) to
/// `defined`.
///
/// Bodies of `if`/`for`/`while`/`loop` are walked recursively with the same
/// `defined` set, so bindings inside a nested body count for the whole
/// selection. Statement kinds not matched below contribute nothing.
fn collect_variable_usage(
    stmt: &Statement,
    used: &mut HashSet<String>,
    defined: &mut HashSet<String>,
) {
    // Recurse into each statement of a nested body, in source order.
    fn visit_all<'a>(
        body: impl IntoIterator<Item = &'a Statement>,
        used: &mut HashSet<String>,
        defined: &mut HashSet<String>,
    ) {
        for inner in body {
            collect_variable_usage(inner, used, defined);
        }
    }

    match stmt {
        Statement::Let { pattern, value, .. } => {
            // The initializer is evaluated before the new binding exists.
            collect_expr_variables(value, used);
            if let windjammer::parser::Pattern::Identifier(binding) = pattern {
                defined.insert(binding.clone());
            }
        }
        Statement::Assignment { target, value, .. } => {
            collect_expr_variables(target, used);
            collect_expr_variables(value, used);
        }
        Statement::Return {
            value: Some(returned),
            ..
        } => collect_expr_variables(returned, used),
        Statement::Expression { expr, .. } => collect_expr_variables(expr, used),
        Statement::If {
            condition,
            then_block,
            else_block,
            ..
        } => {
            collect_expr_variables(condition, used);
            visit_all(then_block, used, defined);
            if let Some(otherwise) = else_block {
                visit_all(otherwise, used, defined);
            }
        }
        Statement::For {
            pattern,
            iterable,
            body,
            ..
        } => {
            if let windjammer::parser::Pattern::Identifier(loop_var) = pattern {
                defined.insert(loop_var.clone());
            }
            collect_expr_variables(iterable, used);
            visit_all(body, used, defined);
        }
        Statement::While {
            condition, body, ..
        } => {
            collect_expr_variables(condition, used);
            visit_all(body, used, defined);
        }
        Statement::Loop { body, .. } => visit_all(body, used, defined),
        _ => {}
    }
}
/// Adds to `used` every identifier that `expr` reads from its enclosing scope.
///
/// Identifiers, binary/unary operators, calls, method calls, field access,
/// indexing and block expressions are inspected; other expression kinds are
/// skipped. The callee of a `Call` is walked like any other expression, so a
/// bare function name is recorded too.
fn collect_expr_variables(expr: &Expression, used: &mut HashSet<String>) {
    match expr {
        Expression::Identifier { name, .. } => {
            used.insert(name.clone());
        }
        Expression::Binary { left, right, .. } => {
            collect_expr_variables(left, used);
            collect_expr_variables(right, used);
        }
        Expression::Unary { operand, .. } => collect_expr_variables(operand, used),
        Expression::Call {
            function, arguments, ..
        } => {
            collect_expr_variables(function, used);
            for (_, arg) in arguments {
                collect_expr_variables(arg, used);
            }
        }
        Expression::MethodCall {
            object, arguments, ..
        } => {
            collect_expr_variables(object, used);
            for (_, arg) in arguments {
                collect_expr_variables(arg, used);
            }
        }
        Expression::FieldAccess { object, .. } => collect_expr_variables(object, used),
        Expression::Index { object, index, .. } => {
            collect_expr_variables(object, used);
            collect_expr_variables(index, used);
        }
        Expression::Block { statements, .. } => {
            // Bindings made inside the block are local to it. Track the
            // block's reads and bindings separately and only let its *free*
            // variables escape; otherwise block-local temporaries would be
            // reported as captured parameters. (Flow-insensitive, like the
            // top-level analysis: `let x = x + 1` inside the block hides `x`.)
            let mut block_used = HashSet::new();
            let mut block_defined = HashSet::new();
            for stmt in statements {
                collect_variable_usage(stmt, &mut block_used, &mut block_defined);
            }
            used.extend(block_used.difference(&block_defined).cloned());
        }
        _ => {}
    }
}
/// Guesses whether the extracted function returns a value.
///
/// A value is assumed when any top-level statement is a `return` with a value,
/// or when the last statement is an expression statement. Actual type
/// inference is not implemented, so the answer is always `Type::String`
/// or `None`.
fn infer_return_type(statements: &[Statement]) -> Option<Type> {
    let has_valued_return = statements
        .iter()
        .any(|stmt| matches!(stmt, Statement::Return { value: Some(_), .. }));
    let ends_in_expression = matches!(statements.last(), Some(Statement::Expression { .. }));

    if has_valued_return || ends_in_expression {
        Some(Type::String)
    } else {
        None
    }
}
/// Renders the skeleton of the extracted function.
///
/// The function is prefixed with `pub ` when `make_public` is set. Parameters
/// come from `analysis.parameters` and the return type from
/// `analysis.return_type`, both Debug-formatted the same way `handle` formats
/// `function_signature`; no `->` clause is emitted when there is no return
/// type. The body is still a placeholder comment: the selected statements are
/// not re-emitted yet.
fn generate_function(
    name: &str,
    _statements: &[Statement],
    analysis: &VariableAnalysis,
    make_public: bool,
) -> String {
    let visibility = if make_public { "pub " } else { "" };
    let params = analysis
        .parameters
        .iter()
        .map(|(param, ty)| format!("{}: {:?}", param, ty))
        .collect::<Vec<_>>()
        .join(", ");
    let return_type = analysis
        .return_type
        .as_ref()
        .map(|ty| format!(" -> {:?}", ty))
        .unwrap_or_default();
    format!(
        "{}fn {}({}){} {{\n // Extracted code here\n}}",
        visibility, name, params, return_type
    )
}
/// Renders the call that replaces the selection: `name(arg1, arg2, ...)`,
/// passing each captured variable by name in `analysis.parameters` order.
fn generate_function_call(name: &str, analysis: &VariableAnalysis) -> String {
    let mut call = String::from(name);
    call.push('(');
    for (position, (arg, _)) in analysis.parameters.iter().enumerate() {
        if position > 0 {
            call.push_str(", ");
        }
        call.push_str(arg);
    }
    call.push(')');
    call
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The handler must not fail with an `McpError` on well-formed arguments;
    /// problems with the selection are reported inside the response body.
    #[tokio::test]
    async fn test_extract_function_basic() {
        let db = Arc::new(Mutex::new(WindjammerDatabase::new()));
        let args = serde_json::json!({
            "code": "fn main() {\n let x = 1;\n let y = 2;\n println!(\"{}\", x + y);\n}",
            "range": {
                "start": { "line": 1, "column": 4 },
                "end": { "line": 2, "column": 17 }
            },
            "function_name": "calculate_sum"
        });
        let result = handle(db, args).await;
        assert!(result.is_ok());
    }

    /// `range` and `function_name` have no serde default, so omitting them
    /// must be rejected as a validation error on the `arguments` field.
    #[tokio::test]
    async fn test_extract_function_rejects_malformed_arguments() {
        let db = Arc::new(Mutex::new(WindjammerDatabase::new()));
        let args = serde_json::json!({ "code": "fn main() {}" });
        let result = handle(db, args).await;
        assert!(matches!(
            &result,
            Err(McpError::ValidationError { field, .. }) if field == "arguments"
        ));
    }
}