use crate::core::{
BackendType, CodeGenBackend, LoweredGraph, OptimizedKernel, TargetSpecification,
};
use crate::{cpp_codegen::CppCodeGen, python_codegen::PythonCodeGen, FxGraph};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use torsh_core::Result;
/// Front end for code generation: maps target names to backends and keeps a cache of
/// previously generated source.
pub struct CodeGenerator {
    /// Previously generated source, keyed by a string built from the target name and a
    /// graph hash.
    cache: Arc<RwLock<HashMap<String, String>>>,
    /// Registered backends, keyed by target name (the defaults are `"python"` and `"cpp"`).
    backends: HashMap<String, Box<dyn CodeGenBackend>>,
}
impl Default for CodeGenerator {
fn default() -> Self {
let mut generator = Self {
backends: HashMap::new(),
cache: Arc::new(RwLock::new(HashMap::new())),
};
generator.register_default_backends();
generator
}
}
impl CodeGenerator {
pub fn new() -> Self {
Self::default()
}
fn register_default_backends(&mut self) {
self.add_backend("python".to_string(), PythonCodeGen::new());
self.add_backend("cpp".to_string(), CppCodeGen::new());
}
pub fn add_backend<T: CodeGenBackend + 'static>(&mut self, name: String, backend: T) {
self.backends.insert(name, Box::new(backend));
}
pub fn generate_code(&self, graph: &FxGraph, target: &str) -> Result<String> {
let cache_key = format!("{}_{}", target, self.graph_hash(graph));
if let Ok(cache) = self.cache.read() {
if let Some(cached_code) = cache.get(&cache_key) {
return Ok(cached_code.clone());
}
}
if let Some(backend) = self.backends.get(target) {
let code = backend.generate(graph)?;
if let Ok(mut cache) = self.cache.write() {
cache.insert(cache_key, code.clone());
}
Ok(code)
} else {
Err(torsh_core::Error::InvalidArgument(format!(
"Unknown target: {}",
target
)))
}
}
pub fn available_targets(&self) -> Vec<&String> {
self.backends.keys().collect()
}
pub fn get_file_extension(&self, target: &str) -> Option<&str> {
self.backends
.get(target)
.map(|backend| backend.file_extension())
}
pub fn get_language_name(&self, target: &str) -> Option<&str> {
self.backends
.get(target)
.map(|backend| backend.language_name())
}
pub fn generate_optimized_kernels(
&self,
graph: &FxGraph,
target_spec: &TargetSpecification,
) -> Result<Vec<OptimizedKernel>> {
Ok(vec![])
}
pub fn lower_to_backend(&self, graph: &FxGraph, backend: BackendType) -> Result<LoweredGraph> {
Ok(LoweredGraph {
nodes: vec![],
edges: vec![],
backend_type: backend,
})
}
pub fn generate_target_specific(
&self,
graph: &FxGraph,
target_spec: &TargetSpecification,
) -> Result<String> {
let target = match target_spec.device {
crate::core::TargetDevice::CPU => "cpp",
crate::core::TargetDevice::CUDA => "cpp", _ => "python", };
self.generate_code(graph, target)
}
pub fn create_lazy_compiler(&self, graph: &FxGraph, target: &str) -> Result<LazyCompiler> {
LazyCompiler::new(graph.clone(), target.to_string())
}
pub fn clear_cache(&self) -> Result<()> {
if let Ok(mut cache) = self.cache.write() {
cache.clear();
}
Ok(())
}
pub fn cache_stats(&self) -> CacheStats {
if let Ok(cache) = self.cache.read() {
CacheStats {
entries: cache.len(),
total_size_bytes: cache.values().map(|v| v.len()).sum(),
}
} else {
CacheStats {
entries: 0,
total_size_bytes: 0,
}
}
}
fn graph_hash(&self, _graph: &FxGraph) -> u64 {
42
}
}
/// Holds a graph and a target name so that code can be obtained later, on request.
pub struct LazyCompiler {
    /// Name of the target the code is compiled for.
    target: String,
    /// Owned copy of the graph to compile.
    graph: FxGraph,
    /// Code stored ahead of time, if any; returned as-is by `get_compiled_code`.
    compiled_code: Option<Arc<CompiledCode>>,
    /// When `true`, `get_compiled_code` produces code even if none is stored.
    compile_on_demand: bool,
}
impl LazyCompiler {
pub fn new(graph: FxGraph, target: String) -> Result<Self> {
Ok(Self {
graph,
target,
compiled_code: None,
compile_on_demand: true,
})
}
pub fn set_compile_on_demand(&mut self, enabled: bool) {
self.compile_on_demand = enabled;
}
pub fn get_compiled_code(&self) -> Result<Arc<CompiledCode>> {
if let Some(ref code) = self.compiled_code {
Ok(code.clone())
} else if self.compile_on_demand {
let compiled = Arc::new(CompiledCode {
source: "// Compiled code".to_string(),
target: self.target.clone(),
is_valid: true,
});
Ok(compiled)
} else {
Err(torsh_core::Error::InvalidState(
"Code not compiled and compile_on_demand is disabled".into(),
))
}
}
}
/// Generated source code together with the name of the target it was produced for.
#[derive(Debug, Clone)]
pub struct CompiledCode {
    /// The generated source text.
    pub source: String,
    /// Name of the target the source was generated for.
    pub target: String,
    /// Validity flag set by the producer; the `is_valid()` method also requires
    /// non-empty source.
    pub is_valid: bool,
}
impl CompiledCode {
    /// Returns `true` only when the code is flagged valid and its source is non-empty.
    pub fn is_valid(&self) -> bool {
        let has_source = !self.source.is_empty();
        has_source && self.is_valid
    }

    /// Borrows the generated source text.
    pub fn source(&self) -> &str {
        self.source.as_str()
    }

    /// Borrows the target name.
    pub fn target(&self) -> &str {
        self.target.as_str()
    }
}
/// Snapshot of the code cache's size.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of cached entries.
    pub entries: usize,
    /// Sum of the byte lengths of all cached source strings.
    pub total_size_bytes: usize,
}
impl CacheStats {
    /// Mean entry size in bytes, rounded down; `0` when there are no entries.
    pub fn average_entry_size(&self) -> usize {
        // `checked_div` yields `None` exactly when `entries` is zero.
        self.total_size_bytes
            .checked_div(self.entries)
            .unwrap_or(0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_code_generator_creation() {
        let generator = CodeGenerator::new();
        let targets = generator.available_targets();
        assert!(targets.len() >= 2);
        for expected in &["python", "cpp"] {
            assert!(targets.iter().any(|name| name.as_str() == *expected));
        }
    }

    #[test]
    fn test_get_file_extension() {
        let generator = CodeGenerator::new();
        let cases = [("python", Some("py")), ("cpp", Some("cpp")), ("unknown", None)];
        for &(target, expected) in &cases {
            assert_eq!(generator.get_file_extension(target), expected);
        }
    }

    #[test]
    fn test_get_language_name() {
        let generator = CodeGenerator::new();
        let cases = [("python", Some("Python")), ("cpp", Some("C++")), ("unknown", None)];
        for &(target, expected) in &cases {
            assert_eq!(generator.get_language_name(target), expected);
        }
    }

    #[test]
    fn test_cache_stats() {
        let stats = CodeGenerator::new().cache_stats();
        assert_eq!(stats.entries, 0);
        assert_eq!(stats.total_size_bytes, 0);
        assert_eq!(stats.average_entry_size(), 0);
    }

    #[test]
    fn test_clear_cache() {
        let result = CodeGenerator::new().clear_cache();
        assert!(result.is_ok());
    }

    #[test]
    fn test_lazy_compiler_creation() {
        // TODO: construct a LazyCompiler once an FxGraph fixture is available.
    }

    #[test]
    fn test_compiled_code_validity() {
        let src = "def test(): pass";
        let code = CompiledCode {
            source: src.to_string(),
            target: "python".to_string(),
            is_valid: true,
        };
        assert!(code.is_valid());
        assert_eq!(code.source(), src);
        assert_eq!(code.target(), "python");
    }
}