use std::sync::Mutex;
use mistralrs::{GgufModelBuilder, Model, RequestBuilder, StopTokens, TextMessageRole};
use tokio::runtime::Runtime;
use crate::{
deobfuscation::renamer::{
context::RenameContext, prompt, validate, RenameProvider, SmartRenameConfig,
},
Result,
};
/// System prompt steering the local model toward emitting a single bare
/// identifier (no prose, punctuation, or quotes) for .NET rename suggestions.
const SYSTEM_PROMPT: &str = "\
You are a .NET identifier naming assistant. Given code context from a deobfuscated \
.NET assembly, suggest a single semantically meaningful identifier name. \
If the context says certain names are already taken, you MUST suggest a different name. \
Respond with ONLY the identifier name — no explanations, no punctuation, no quotes.";
/// Rename-suggestion provider backed by a local GGUF model (via mistralrs).
///
/// The model is loaded lazily by `initialize` and stored behind a `Mutex` so
/// `suggest_name` can run through a shared reference; holding the lock also
/// serializes inference on the single loaded model.
pub struct LocalProvider {
    // User-supplied settings: model path, sampling parameters, limits.
    config: SmartRenameConfig,
    // `None` until `initialize` succeeds; reset to `None` by `shutdown`.
    state: Mutex<Option<InferenceState>>,
}
/// Loaded model together with the tokio runtime that drives its async API.
///
/// Kept as one unit so both are dropped together on `shutdown`.
struct InferenceState {
    // Chat-capable GGUF model handle.
    model: Model,
    // Dedicated runtime; sync call paths bridge into it with `block_on`.
    runtime: Runtime,
}
impl LocalProvider {
    /// Creates a provider in the uninitialized state; no model is loaded yet.
    pub fn new(config: SmartRenameConfig) -> Self {
        let state = Mutex::new(None);
        Self { config, state }
    }

    /// Renders the FIM-style prompt as a plain chat prompt: the FIM control
    /// tokens are stripped and the identifier position is marked with `???`.
    fn build_chat_prompt(&self, context: &RenameContext) -> String {
        let (raw_prefix, raw_suffix) =
            prompt::build_fim_prompt(context, self.config.max_phases_in_prompt);
        // The chat model does not understand FIM tokens, so remove them.
        let before = raw_prefix
            .replace("<|fim_prefix|>", "")
            .replace("<|fim_suffix|>", "");
        let after = raw_suffix.replace("<|fim_middle|>", "");
        let chat_prompt = format!(
            "Suggest a name for the identifier marked with ??? in this .NET code:\n\n\
             {before}???{after}"
        );
        log::debug!("Chat prompt:\n{}", chat_prompt);
        chat_prompt
    }

    /// Runs one blocking chat completion and returns the trimmed model output,
    /// or `None` when the model produced nothing usable.
    ///
    /// # Errors
    ///
    /// Returns `Error::Deobfuscation` when the underlying request fails.
    fn infer(&self, state: &InferenceState, context: &RenameContext) -> Result<Option<String>> {
        let request = RequestBuilder::new()
            .add_message(TextMessageRole::System, SYSTEM_PROMPT)
            .add_message(TextMessageRole::User, self.build_chat_prompt(context))
            .set_sampler_max_len(self.config.max_tokens as usize)
            .set_sampler_temperature(self.config.temperature)
            .set_sampler_stop_toks(StopTokens::Seqs(self.config.stop_sequences.clone()));

        // Bridge the async mistralrs API onto this synchronous call path.
        let response = state
            .runtime
            .block_on(state.model.send_chat_request(request))
            .map_err(|e| crate::Error::Deobfuscation(format!("Model inference failed: {e}")))?;

        let first_choice = response.choices.first();
        if let Some(choice) = first_choice {
            log::debug!(
                "Model response: content={:?} finish_reason={:?}",
                choice.message.content,
                choice.finish_reason
            );
        }

        // Trim whitespace and treat an empty completion as "no suggestion".
        Ok(first_choice
            .and_then(|c| c.message.content.as_deref())
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .map(str::to_string))
    }
}
impl RenameProvider for LocalProvider {
    fn name(&self) -> &'static str {
        "LocalProvider"
    }

    /// Loads the GGUF model from `config.model_path` and prepares the
    /// inference runtime.
    ///
    /// # Errors
    ///
    /// Returns `Error::Deobfuscation` when the model file is missing, the path
    /// cannot be resolved, its file name is not valid UTF-8, the tokio runtime
    /// cannot be created, or the model fails to load.
    fn initialize(&mut self) -> Result<()> {
        if !self.config.model_path.exists() {
            return Err(crate::Error::Deobfuscation(format!(
                "Smart rename model not found: {}",
                self.config.model_path.display()
            )));
        }
        // mistralrs exposes an async API; drive it with a dedicated runtime.
        let runtime = Runtime::new().map_err(|e| {
            crate::Error::Deobfuscation(format!("Failed to create tokio runtime: {e}"))
        })?;
        let model_path = self.config.model_path.canonicalize().map_err(|e| {
            crate::Error::Deobfuscation(format!(
                "Failed to resolve model path {}: {e}",
                self.config.model_path.display()
            ))
        })?;
        // Canonical paths are absolute, so `parent` is only `None` for the root.
        let parent = model_path
            .parent()
            .unwrap_or_else(|| std::path::Path::new("."));
        // BUGFIX: this previously fell back to a hard-coded "model.gguf" when
        // the file name was absent or not valid UTF-8, which would make the
        // builder look for a *different* file in the model directory. Fail
        // loudly instead of loading the wrong model.
        let filename = model_path
            .file_name()
            .and_then(|f| f.to_str())
            .ok_or_else(|| {
                crate::Error::Deobfuscation(format!(
                    "Model path has no valid UTF-8 file name: {}",
                    model_path.display()
                ))
            })?
            .to_string();
        let force_cpu = self.config.force_cpu;
        let model = runtime.block_on(async {
            let mut builder = GgufModelBuilder::new(parent.display().to_string(), vec![filename]);
            if force_cpu {
                builder = builder.with_force_cpu();
            }
            builder
                .build()
                .await
                .map_err(|e| crate::Error::Deobfuscation(format!("Model load failed: {e}")))
        })?;
        log::info!(
            "Smart rename model loaded: {}",
            self.config.model_path.display()
        );
        *self.state.lock().unwrap() = Some(InferenceState { model, runtime });
        Ok(())
    }

    /// Suggests a validated identifier name, or `None` when the identifier
    /// kind is unknown, the provider is uninitialized, or the model output
    /// fails validation.
    fn suggest_name(&self, context: &RenameContext) -> Result<Option<String>> {
        // Without a kind we cannot validate naming conventions; bail early.
        let Some(kind) = context.kind else {
            return Ok(None);
        };
        // Holding the lock for the whole call serializes inference on the
        // single loaded model.
        let guard = self.state.lock().unwrap();
        let Some(ref state) = *guard else {
            return Ok(None);
        };
        match self.infer(state, context)? {
            Some(name) => Ok(validate::validate_name(
                &name,
                kind,
                self.config.max_name_length,
            )),
            None => Ok(None),
        }
    }

    /// Drops the model and its runtime, releasing their resources.
    fn shutdown(&mut self) -> Result<()> {
        *self.state.lock().unwrap() = None;
        log::info!("Smart rename model unloaded");
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use crate::deobfuscation::renamer::{
        context::{IdentifierKind, RenameContext},
        providers::local::LocalProvider,
        RenameProvider, SmartRenameConfig,
    };
    use std::path::PathBuf;

    /// Builds a provider from an all-default configuration.
    fn default_provider() -> LocalProvider {
        LocalProvider::new(SmartRenameConfig::default())
    }

    #[test]
    fn test_local_provider_config() {
        let cfg = SmartRenameConfig {
            model_path: PathBuf::from("/nonexistent/model.gguf"),
            max_tokens: 20,
            threads: 4,
            force_cpu: false,
            ..SmartRenameConfig::default()
        };
        assert_eq!(LocalProvider::new(cfg).name(), "LocalProvider");
    }

    #[test]
    fn test_local_provider_uninitialized_returns_none() {
        let provider = default_provider();
        let ctx = RenameContext {
            kind: Some(IdentifierKind::Method),
            ..Default::default()
        };
        assert_eq!(provider.suggest_name(&ctx).unwrap(), None);
    }

    #[test]
    fn test_local_provider_no_kind_returns_none() {
        let provider = default_provider();
        let ctx = RenameContext::default();
        assert_eq!(provider.suggest_name(&ctx).unwrap(), None);
    }

    #[test]
    fn test_local_provider_missing_model_file() {
        let cfg = SmartRenameConfig {
            model_path: PathBuf::from("/nonexistent/model.gguf"),
            ..SmartRenameConfig::default()
        };
        let mut provider = LocalProvider::new(cfg);
        assert!(
            provider.initialize().is_err(),
            "Should fail with missing model file"
        );
    }

    #[test]
    fn test_chat_prompt_construction() {
        let provider = default_provider();
        let ctx = RenameContext {
            kind: Some(IdentifierKind::Method),
            dotnet_type: Some("void".to_string()),
            call_site_skeleton: Some(" File.WriteAllText(var_0, var_1);".to_string()),
            ..Default::default()
        };
        let rendered = provider.build_chat_prompt(&ctx);
        assert!(rendered.contains("???"), "Should contain placeholder");
        assert!(
            rendered.contains("File.WriteAllText"),
            "Should contain call target"
        );
        assert!(
            !rendered.contains("<|fim_prefix|>"),
            "Should not contain FIM tokens"
        );
        assert!(
            !rendered.contains("<|fim_middle|>"),
            "Should not contain FIM tokens"
        );
    }

    /// End-to-end inference against a real model; ignored by default and only
    /// run when `DOTSCOPE_SMART_RENAME_MODEL` points at a GGUF file.
    #[test]
    #[ignore]
    fn test_local_provider_inference() {
        let Ok(model_env) = std::env::var("DOTSCOPE_SMART_RENAME_MODEL") else {
            eprintln!("Skipping: DOTSCOPE_SMART_RENAME_MODEL not set");
            return;
        };
        let cfg = SmartRenameConfig {
            model_path: PathBuf::from(model_env),
            max_tokens: 20,
            threads: 0,
            force_cpu: false,
            ..SmartRenameConfig::default()
        };
        let mut provider = LocalProvider::new(cfg);
        provider.initialize().unwrap();
        let ctx = RenameContext {
            kind: Some(IdentifierKind::Method),
            call_targets: vec![
                "System.IO.File::ReadAllText".to_string(),
                "System.Text.Encoding::GetBytes".to_string(),
            ],
            dotnet_type: Some("byte[]".to_string()),
            ..Default::default()
        };
        let name = provider.suggest_name(&ctx).unwrap();
        eprintln!("Generated name: {name:?}");
        assert!(name.is_some(), "Model should produce a name");
        let name = name.unwrap();
        assert!(
            name.chars().next().unwrap().is_ascii_uppercase(),
            "Method name '{name}' should be PascalCase"
        );
        provider.shutdown().unwrap();
    }
}