use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::sync::{Arc, RwLock};
/// Opaque numeric handle identifying one transform instance.
///
/// Two ids compare equal exactly when the `u64` values they wrap are equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TransformId(u64);

impl TransformId {
    /// Wraps the raw value `id`.
    pub fn new(id: u64) -> Self {
        TransformId(id)
    }

    /// Returns the raw `u64` held by this id.
    pub fn id(&self) -> u64 {
        let TransformId(raw) = *self;
        raw
    }
}
/// Discriminates the kind of transform a `TransformMetadata` record describes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransformType {
/// Tag assigned by `JitTransform::new`.
Jit,
/// Tag assigned by `VmapTransform::new`.
Vmap,
/// Tag assigned by `PmapTransform::new`.
Pmap,
/// Tag assigned by `GradTransform::new`.
Grad,
/// Not assigned by any constructor visible in this file; presumably marks a
/// combined value-and-gradient transform (TODO confirm).
ValueAndGrad,
/// Tag assigned by `ComposedTransform::new`; presumably also intended for
/// user-defined transforms (TODO confirm).
Custom,
}
/// Descriptive record attached to a transform instance.
#[derive(Debug, Clone)]
pub struct TransformMetadata {
    /// Identifier of the transform this record describes.
    pub id: TransformId,
    /// Kind of transform.
    pub transform_type: TransformType,
    /// Label supplied at construction.
    pub name: String,
    /// Counter advanced by [`TransformMetadata::increment_count`]; starts at 0.
    pub application_count: usize,
    /// Flag raised by [`TransformMetadata::cached`]; `false` on construction.
    pub is_cached: bool,
    /// String annotations added through [`TransformMetadata::with_metadata`].
    pub custom_metadata: HashMap<String, String>,
}

impl TransformMetadata {
    /// Builds a record with a zero application count, `is_cached == false`
    /// and no custom annotations.
    pub fn new(id: TransformId, transform_type: TransformType, name: impl Into<String>) -> Self {
        let name = name.into();
        TransformMetadata {
            id,
            transform_type,
            name,
            application_count: 0,
            is_cached: false,
            custom_metadata: HashMap::new(),
        }
    }

    /// Builder step: stores `value` under `key`, replacing any value already
    /// stored under that key, and hands the record back.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (owned_key, owned_value) = (key.into(), value.into());
        self.custom_metadata.insert(owned_key, owned_value);
        self
    }

    /// Builder step: returns the record with `is_cached` set to `true`.
    pub fn cached(self) -> Self {
        Self {
            is_cached: true,
            ..self
        }
    }

    /// Advances `application_count` by one.
    pub fn increment_count(&mut self) {
        let counter = &mut self.application_count;
        *counter += 1;
    }
}
/// Interface for a transform that maps an `Input` to an `Output` and exposes
/// its metadata and a cache-invalidation hook.
///
/// No implementor is visible in this file, so the notes below describe the
/// signatures only.
pub trait Jittable<Input, Output> {
/// Consumes `input` and returns an `Output`.
fn jit_apply(&self, input: Input) -> Output;
/// Borrows the metadata record associated with this transform.
fn jit_metadata(&self) -> &TransformMetadata;
/// Requires exclusive access; judging by the name it presumably discards
/// cached results (TODO confirm against an implementor).
fn invalidate_cache(&mut self);
}
/// Interface for a transform that is applied with an explicit input dimension.
///
/// No implementor is visible in this file, so the notes below describe the
/// signatures only.
pub trait Vectorizable<Input, Output> {
/// Consumes `input` together with a dimension index `in_dim` and returns an
/// `Output`. Presumably `in_dim` selects the axis to map over (TODO confirm).
fn vmap_apply(&self, input: Input, in_dim: usize) -> Output;
/// Returns a dimension index; which axis it refers to is not established by
/// the visible code (TODO confirm).
fn vmap_dim(&self) -> usize;
/// Borrows the metadata record associated with this transform.
fn vmap_metadata(&self) -> &TransformMetadata;
}
/// Interface for a transform that is applied with an explicit list of device
/// indices.
///
/// No implementor is visible in this file, so the notes below describe the
/// signatures only.
pub trait Parallelizable<Input, Output> {
/// Consumes `input` together with a slice of device indices and returns an
/// `Output`. How the work is distributed over `devices` is up to the
/// implementor (none is visible here).
fn pmap_apply(&self, input: Input, devices: &[usize]) -> Output;
/// Returns a count; presumably the number of devices or replicas used
/// (TODO confirm).
fn pmap_degree(&self) -> usize;
/// Borrows the metadata record associated with this transform.
fn pmap_metadata(&self) -> &TransformMetadata;
}
/// One memoized call result stored by [`JitTransform`].
#[derive(Debug, Clone)]
struct JitCacheEntry<Input, Output> {
    /// The input that produced `output`. It is compared on every lookup so
    /// that two different inputs sharing a 64-bit hash are never confused.
    input: Input,
    /// The value the wrapped function returned for `input`.
    output: Output,
    /// Number of cache hits served from this entry since it was inserted.
    hit_count: usize,
    /// Time of insertion or of the most recent hit; eviction removes the
    /// entry with the smallest value.
    last_access: std::time::Instant,
}

/// Memoizing wrapper around a function.
///
/// Results are cached per input. The cache is indexed by the input's 64-bit
/// hash, and each entry also keeps the input itself so a hash match is only
/// treated as a hit when the inputs are equal.
pub struct JitTransform<F, Input, Output>
where
    F: Fn(Input) -> Output,
    Input: Clone + Hash + Eq,
    Output: Clone,
{
    func: F,
    metadata: Arc<RwLock<TransformMetadata>>,
    cache: Arc<RwLock<HashMap<u64, JitCacheEntry<Input, Output>>>>,
    /// Maximum number of entries; `0` disables eviction.
    max_cache_size: usize,
    /// Hits recorded on entries that have since been evicted, replaced or
    /// cleared. Only modified while the cache write lock is held.
    retired_hits: std::sync::atomic::AtomicUsize,
    _phantom: PhantomData<Input>,
}

impl<F, Input, Output> JitTransform<F, Input, Output>
where
    F: Fn(Input) -> Output,
    Input: Clone + Hash + Eq,
    Output: Clone,
{
    /// Entry limit used by [`JitTransform::new`].
    const DEFAULT_MAX_CACHE_SIZE: usize = 1000;

    /// Wraps `func` with an empty cache limited to 1000 entries.
    ///
    /// The id comes from a counter private to this constructor, so it is
    /// unique among `JitTransform` values but may coincide with ids issued by
    /// the other transform kinds in this file.
    pub fn new(func: F) -> Self {
        static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        let id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        Self {
            func,
            metadata: Arc::new(RwLock::new(
                TransformMetadata::new(TransformId::new(id), TransformType::Jit, "jit").cached(),
            )),
            cache: Arc::new(RwLock::new(HashMap::new())),
            max_cache_size: Self::DEFAULT_MAX_CACHE_SIZE,
            retired_hits: std::sync::atomic::AtomicUsize::new(0),
            _phantom: PhantomData,
        }
    }

    /// Like [`JitTransform::new`] but with a caller-chosen entry limit.
    ///
    /// A `max_size` of `0` means the cache is never evicted from.
    pub fn with_cache_size(func: F, max_size: usize) -> Self {
        let mut transform = Self::new(func);
        transform.max_cache_size = max_size;
        transform
    }

    /// Hashes `input` with the standard library's default hasher.
    fn hash_input(input: &Input) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        input.hash(&mut hasher);
        hasher.finish()
    }

    /// Adds `hits` to the tally of hits whose cache entries no longer exist.
    fn retire_hits(&self, hits: usize) {
        self.retired_hits
            .fetch_add(hits, std::sync::atomic::Ordering::SeqCst);
    }

    /// Returns `func(input)`, reusing a cached result when one exists for an
    /// equal input.
    ///
    /// Every call, hit or miss, advances the metadata's application count.
    /// The wrapped function runs without any lock held.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    pub fn apply(&self, input: Input) -> Output {
        let hash = Self::hash_input(&input);
        {
            let mut cache = self.cache.write().expect("lock should not be poisoned");
            // A matching hash alone is not proof of a matching input: verify
            // equality before serving the stored output.
            if let Some(entry) = cache.get_mut(&hash).filter(|entry| entry.input == input) {
                entry.hit_count += 1;
                entry.last_access = std::time::Instant::now();
                // Counted while the cache lock is held so `cache_stats` never
                // observes a hit that is missing from the application count.
                self.metadata
                    .write()
                    .expect("lock should not be poisoned")
                    .increment_count();
                return entry.output.clone();
            }
        }
        let output = (self.func)(input.clone());
        {
            let mut cache = self.cache.write().expect("lock should not be poisoned");
            // Replacing an existing slot (hash collision, or another thread
            // inserted first) does not grow the map, so nothing is evicted.
            let grows = !cache.contains_key(&hash);
            if grows && self.max_cache_size > 0 && cache.len() >= self.max_cache_size {
                let stalest = cache
                    .iter()
                    .min_by_key(|(_, entry)| entry.last_access)
                    .map(|(&key, _)| key);
                if let Some(evicted) = stalest.and_then(|key| cache.remove(&key)) {
                    self.retire_hits(evicted.hit_count);
                }
            }
            let fresh = JitCacheEntry {
                input,
                output: output.clone(),
                hit_count: 0,
                last_access: std::time::Instant::now(),
            };
            if let Some(replaced) = cache.insert(hash, fresh) {
                self.retire_hits(replaced.hit_count);
            }
            self.metadata
                .write()
                .expect("lock should not be poisoned")
                .increment_count();
        }
        output
    }

    /// Reports the current cache size together with lifetime hit and miss
    /// totals.
    ///
    /// `total_hits` includes hits on entries that were later evicted,
    /// replaced or cleared; `total_misses` is the application count minus
    /// `total_hits`.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    pub fn cache_stats(&self) -> CacheStats {
        let cache = self.cache.read().expect("lock should not be poisoned");
        let live_hits: usize = cache.values().map(|entry| entry.hit_count).sum();
        let total_hits = live_hits
            + self
                .retired_hits
                .load(std::sync::atomic::Ordering::SeqCst);
        let applications = self
            .metadata
            .read()
            .expect("lock should not be poisoned")
            .application_count;
        CacheStats {
            size: cache.len(),
            total_hits,
            total_misses: applications.saturating_sub(total_hits),
            max_size: self.max_cache_size,
        }
    }

    /// Removes every cached result. Hit totals reported by
    /// [`JitTransform::cache_stats`] are preserved.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    pub fn clear_cache(&mut self) {
        let mut cache = self.cache.write().expect("lock should not be poisoned");
        let dropped_hits: usize = cache.values().map(|entry| entry.hit_count).sum();
        self.retire_hits(dropped_hits);
        cache.clear();
    }

    /// Returns a snapshot (clone) of the transform's metadata.
    ///
    /// # Panics
    ///
    /// Panics if the metadata lock is poisoned.
    pub fn metadata(&self) -> TransformMetadata {
        self.metadata
            .read()
            .expect("lock should not be poisoned")
            .clone()
    }
}
/// Snapshot of cache occupancy and hit/miss counters.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries currently stored.
    pub size: usize,
    /// Number of lookups counted as hits.
    pub total_hits: usize,
    /// Number of lookups counted as misses.
    pub total_misses: usize,
    /// Entry limit; `0` is treated by [`CacheStats::is_full`] as "no limit".
    pub max_size: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, or `0.0` when there were no
    /// lookups at all.
    pub fn hit_rate(&self) -> f64 {
        match self.total_hits + self.total_misses {
            0 => 0.0,
            lookups => self.total_hits as f64 / lookups as f64,
        }
    }

    /// `true` when a non-zero limit is set and the size has reached it.
    pub fn is_full(&self) -> bool {
        if self.max_size == 0 {
            return false;
        }
        self.max_size <= self.size
    }
}
/// Bundles a function with an input and an output dimension index.
///
/// The visible code only stores and exposes these values; it does not apply
/// the function over a batch itself.
pub struct VmapTransform<F, Input, Output>
where
    F: Fn(Input) -> Output,
{
    func: F,
    in_dim: usize,
    out_dim: usize,
    metadata: Arc<RwLock<TransformMetadata>>,
    _phantom: PhantomData<(Input, Output)>,
}

impl<F, Input, Output> VmapTransform<F, Input, Output>
where
    F: Fn(Input) -> Output,
{
    /// Wraps `func`, using `in_dim` for both the input and the output
    /// dimension.
    ///
    /// The id comes from a counter private to this constructor, so it is
    /// unique among `VmapTransform` values but may coincide with ids issued
    /// by the other transform kinds in this file.
    pub fn new(func: F, in_dim: usize) -> Self {
        static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        let raw_id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        let record = TransformMetadata::new(TransformId::new(raw_id), TransformType::Vmap, "vmap");
        VmapTransform {
            func,
            in_dim,
            out_dim: in_dim,
            metadata: Arc::new(RwLock::new(record)),
            _phantom: PhantomData,
        }
    }

    /// Like [`VmapTransform::new`] but with an output dimension that may
    /// differ from the input dimension.
    pub fn with_out_dim(func: F, in_dim: usize, out_dim: usize) -> Self {
        Self {
            out_dim,
            ..Self::new(func, in_dim)
        }
    }

    /// Input dimension index given at construction.
    pub fn in_dim(&self) -> usize {
        self.in_dim
    }

    /// Output dimension index.
    pub fn out_dim(&self) -> usize {
        self.out_dim
    }

    /// Returns a snapshot (clone) of the transform's metadata.
    ///
    /// # Panics
    ///
    /// Panics if the metadata lock is poisoned.
    pub fn metadata(&self) -> TransformMetadata {
        let guard = self.metadata.read().expect("lock should not be poisoned");
        TransformMetadata::clone(&guard)
    }

    /// Borrows the wrapped function.
    pub fn apply_marker(&self) -> &F {
        let Self { func, .. } = self;
        func
    }
}
/// Bundles a thread-shareable function with a device count.
///
/// The visible code only stores and exposes these values; it does not
/// dispatch work to devices itself.
pub struct PmapTransform<F, Input, Output>
where
    F: Fn(Input) -> Output + Send + Sync,
    Input: Send,
    Output: Send,
{
    func: Arc<F>,
    num_devices: usize,
    metadata: Arc<RwLock<TransformMetadata>>,
    _phantom: PhantomData<(Input, Output)>,
}

impl<F, Input, Output> PmapTransform<F, Input, Output>
where
    F: Fn(Input) -> Output + Send + Sync,
    Input: Send,
    Output: Send,
{
    /// Wraps `func` in an `Arc` and records `num_devices`.
    ///
    /// The id comes from a counter private to this constructor, so it is
    /// unique among `PmapTransform` values but may coincide with ids issued
    /// by the other transform kinds in this file.
    pub fn new(func: F, num_devices: usize) -> Self {
        static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        let raw_id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        let record = TransformMetadata::new(TransformId::new(raw_id), TransformType::Pmap, "pmap");
        PmapTransform {
            func: Arc::new(func),
            num_devices,
            metadata: Arc::new(RwLock::new(record)),
            _phantom: PhantomData,
        }
    }

    /// Device count given at construction.
    pub fn num_devices(&self) -> usize {
        self.num_devices
    }

    /// Returns a snapshot (clone) of the transform's metadata.
    ///
    /// # Panics
    ///
    /// Panics if the metadata lock is poisoned.
    pub fn metadata(&self) -> TransformMetadata {
        let guard = self.metadata.read().expect("lock should not be poisoned");
        TransformMetadata::clone(&guard)
    }

    /// Borrows the shared handle to the wrapped function.
    pub fn func(&self) -> &Arc<F> {
        let Self { func, .. } = self;
        func
    }
}
/// Bundles a function with a list of argument indices.
///
/// The visible code only stores and exposes these values; it does not
/// compute any derivative itself.
pub struct GradTransform<F, Input, Output>
where
    F: Fn(Input) -> Output,
{
    func: F,
    argnums: Vec<usize>,
    metadata: Arc<RwLock<TransformMetadata>>,
    _phantom: PhantomData<(Input, Output)>,
}

impl<F, Input, Output> GradTransform<F, Input, Output>
where
    F: Fn(Input) -> Output,
{
    /// Wraps `func` and records `argnums` as given (no validation is done).
    ///
    /// The id comes from a counter private to this constructor, so it is
    /// unique among `GradTransform` values but may coincide with ids issued
    /// by the other transform kinds in this file.
    pub fn new(func: F, argnums: Vec<usize>) -> Self {
        static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        let raw_id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        let record = TransformMetadata::new(TransformId::new(raw_id), TransformType::Grad, "grad");
        GradTransform {
            func,
            argnums,
            metadata: Arc::new(RwLock::new(record)),
            _phantom: PhantomData,
        }
    }

    /// Argument indices given at construction, in the order supplied.
    pub fn argnums(&self) -> &[usize] {
        self.argnums.as_slice()
    }

    /// Returns a snapshot (clone) of the transform's metadata.
    ///
    /// # Panics
    ///
    /// Panics if the metadata lock is poisoned.
    pub fn metadata(&self) -> TransformMetadata {
        let guard = self.metadata.read().expect("lock should not be poisoned");
        TransformMetadata::clone(&guard)
    }

    /// Borrows the wrapped function.
    pub fn func(&self) -> &F {
        let Self { func, .. } = self;
        func
    }
}
/// Sequential composition of two functions: `second(first(input))`.
pub struct ComposedTransform<F1, F2, Input, Intermediate, Output>
where
    F1: Fn(Input) -> Intermediate,
    F2: Fn(Intermediate) -> Output,
{
    first: F1,
    second: F2,
    metadata: Arc<RwLock<TransformMetadata>>,
    _phantom: PhantomData<(Input, Intermediate, Output)>,
}

impl<F1, F2, Input, Intermediate, Output> ComposedTransform<F1, F2, Input, Intermediate, Output>
where
    F1: Fn(Input) -> Intermediate,
    F2: Fn(Intermediate) -> Output,
{
    /// Pairs `first` and `second`; metadata is tagged `Custom` and named
    /// `"composed"`.
    ///
    /// The id comes from a counter private to this constructor, so it is
    /// unique among `ComposedTransform` values but may coincide with ids
    /// issued by the other transform kinds in this file.
    pub fn new(first: F1, second: F2) -> Self {
        static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        let raw_id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        let record =
            TransformMetadata::new(TransformId::new(raw_id), TransformType::Custom, "composed");
        ComposedTransform {
            first,
            second,
            metadata: Arc::new(RwLock::new(record)),
            _phantom: PhantomData,
        }
    }

    /// Feeds `input` through `first`, then feeds that result through
    /// `second`. The metadata's application count is not touched.
    pub fn apply(&self, input: Input) -> Output {
        let Self { first, second, .. } = self;
        second(first(input))
    }

    /// Returns a snapshot (clone) of the transform's metadata.
    ///
    /// # Panics
    ///
    /// Panics if the metadata lock is poisoned.
    pub fn metadata(&self) -> TransformMetadata {
        let guard = self.metadata.read().expect("lock should not be poisoned");
        TransformMetadata::clone(&guard)
    }
}
/// Lock-protected table of metadata records keyed by [`TransformId`].
pub struct TransformRegistry {
    transforms: Arc<RwLock<HashMap<TransformId, TransformMetadata>>>,
}

impl TransformRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        let table = HashMap::new();
        TransformRegistry {
            transforms: Arc::new(RwLock::new(table)),
        }
    }

    /// Stores `metadata` under its own `id`, replacing any record already
    /// stored under that id.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn register(&self, metadata: TransformMetadata) {
        let mut table = self
            .transforms
            .write()
            .expect("transforms lock should not be poisoned");
        let key = metadata.id;
        table.insert(key, metadata);
    }

    /// Returns a clone of the record stored under `id`, if any.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn get(&self, id: TransformId) -> Option<TransformMetadata> {
        let table = self.transforms.read().expect("lock should not be poisoned");
        match table.get(&id) {
            Some(record) => Some(record.clone()),
            None => None,
        }
    }

    /// Returns clones of all records whose type equals `transform_type`, in
    /// the map's iteration order.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn get_by_type(&self, transform_type: TransformType) -> Vec<TransformMetadata> {
        let table = self
            .transforms
            .read()
            .expect("transforms lock should not be poisoned");
        let mut matching = Vec::new();
        for record in table.values() {
            if record.transform_type == transform_type {
                matching.push(record.clone());
            }
        }
        matching
    }

    /// Number of registered records.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn len(&self) -> usize {
        let table = self.transforms.read().expect("lock should not be poisoned");
        table.len()
    }

    /// `true` when no record is registered.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn is_empty(&self) -> bool {
        let table = self.transforms.read().expect("lock should not be poisoned");
        table.is_empty()
    }

    /// Removes every record.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn clear(&self) {
        let mut table = self.transforms.write().expect("lock should not be poisoned");
        table.clear();
    }
}

impl Default for TransformRegistry {
    /// Equivalent to [`TransformRegistry::new`].
    fn default() -> Self {
        TransformRegistry::new()
    }
}
// Unit tests for the transform wrappers, the registry, metadata building and
// cache statistics.
#[cfg(test)]
mod tests {
use super::*;
// The wrapper returns the closure's result for new and for repeated inputs.
#[test]
fn test_jit_basic() {
let jit_fn = JitTransform::new(|x: i32| x * 2);
assert_eq!(jit_fn.apply(5), 10);
assert_eq!(jit_fn.apply(10), 20);
// Same input as the first call.
assert_eq!(jit_fn.apply(5), 10); }
// Each of two inputs is applied twice: one miss and one hit per input.
#[test]
fn test_jit_cache_stats() {
let jit_fn = JitTransform::new(|x: i32| x * 2);
jit_fn.apply(5);
jit_fn.apply(5); jit_fn.apply(10);
jit_fn.apply(10);
let stats = jit_fn.cache_stats();
// Expect 2 entries, 2 hits, 2 misses, hence a hit rate of 0.5.
assert_eq!(stats.size, 2); assert_eq!(stats.total_hits, 2); assert_eq!(stats.total_misses, 2); assert_eq!(stats.hit_rate(), 0.5);
}
// With a limit of 2, inserting a third distinct input keeps the size at 2.
#[test]
fn test_jit_cache_eviction() {
let jit_fn = JitTransform::with_cache_size(|x: i32| x * 2, 2);
jit_fn.apply(1);
jit_fn.apply(2);
jit_fn.apply(3);
let stats = jit_fn.cache_stats();
assert_eq!(stats.size, 2); }
// `clear_cache` empties the cache.
#[test]
fn test_jit_clear_cache() {
let mut jit_fn = JitTransform::new(|x: i32| x * 2);
jit_fn.apply(5);
jit_fn.apply(10);
assert_eq!(jit_fn.cache_stats().size, 2);
jit_fn.clear_cache();
assert_eq!(jit_fn.cache_stats().size, 0);
}
// `VmapTransform::new` uses the input dimension for the output dimension too.
#[test]
fn test_vmap_basic() {
let vmap_fn = VmapTransform::new(|x: f32| x * x, 0);
assert_eq!(vmap_fn.in_dim(), 0);
assert_eq!(vmap_fn.out_dim(), 0);
}
// `with_out_dim` lets the two dimensions differ.
#[test]
fn test_vmap_different_dims() {
let vmap_fn = VmapTransform::with_out_dim(|x: f32| x * x, 0, 1);
assert_eq!(vmap_fn.in_dim(), 0);
assert_eq!(vmap_fn.out_dim(), 1);
}
// The device count passed to the constructor is reported back.
#[test]
fn test_pmap_basic() {
let pmap_fn = PmapTransform::new(|x: i32| x * 2, 4);
assert_eq!(pmap_fn.num_devices(), 4);
}
// The argument indices passed to the constructor are reported back.
#[test]
fn test_grad_basic() {
let grad_fn = GradTransform::new(|x: f32| x * x, vec![0]);
assert_eq!(grad_fn.argnums(), &[0]);
}
// Several argument indices are kept in the order given.
#[test]
fn test_grad_multiple_args() {
let grad_fn = GradTransform::new(|_: (f32, f32)| 0.0, vec![0, 1]);
assert_eq!(grad_fn.argnums(), &[0, 1]);
}
// Composition applies the first closure, then the second: (5 * 2) + 1.
#[test]
fn test_composed_transform() {
let composed = ComposedTransform::new(|x: i32| x * 2, |x: i32| x + 1);
assert_eq!(composed.apply(5), 11); }
// A registered record can be looked up by its id.
#[test]
fn test_transform_registry() {
let registry = TransformRegistry::new();
assert!(registry.is_empty());
let metadata = TransformMetadata::new(TransformId::new(1), TransformType::Jit, "test");
registry.register(metadata.clone());
assert_eq!(registry.len(), 1);
assert_eq!(
registry
.get(TransformId::new(1))
.expect("get should succeed")
.name,
"test"
);
}
// `get_by_type` returns only the records with the requested type.
#[test]
fn test_transform_registry_by_type() {
let registry = TransformRegistry::new();
registry.register(TransformMetadata::new(
TransformId::new(1),
TransformType::Jit,
"jit1",
));
registry.register(TransformMetadata::new(
TransformId::new(2),
TransformType::Vmap,
"vmap1",
));
registry.register(TransformMetadata::new(
TransformId::new(3),
TransformType::Jit,
"jit2",
));
let jit_transforms = registry.get_by_type(TransformType::Jit);
assert_eq!(jit_transforms.len(), 2);
}
// The application count starts at 0 and grows by one per increment.
#[test]
fn test_transform_metadata() {
let mut metadata = TransformMetadata::new(TransformId::new(1), TransformType::Jit, "test");
assert_eq!(metadata.application_count, 0);
metadata.increment_count();
assert_eq!(metadata.application_count, 1);
}
// Chained `with_metadata` calls store every key/value pair.
#[test]
fn test_transform_metadata_custom() {
let metadata = TransformMetadata::new(TransformId::new(1), TransformType::Custom, "test")
.with_metadata("key1", "value1")
.with_metadata("key2", "value2");
assert_eq!(
metadata
.custom_metadata
.get("key1")
.expect("key1 should exist"),
"value1"
);
assert_eq!(
metadata
.custom_metadata
.get("key2")
.expect("key2 should exist"),
"value2"
);
}
// 8 hits out of 10 lookups gives 0.8; 5 of 10 entries is not full.
#[test]
fn test_cache_stats_hit_rate() {
let stats = CacheStats {
size: 5,
total_hits: 8,
total_misses: 2,
max_size: 10,
};
assert_eq!(stats.hit_rate(), 0.8);
assert!(!stats.is_full());
}
// A size equal to the limit counts as full.
#[test]
fn test_cache_stats_full() {
let stats = CacheStats {
size: 10,
total_hits: 5,
total_misses: 5,
max_size: 10,
};
assert!(stats.is_full());
}
// Ids compare by their wrapped value.
#[test]
fn test_transform_id_equality() {
let id1 = TransformId::new(42);
let id2 = TransformId::new(42);
let id3 = TransformId::new(43);
assert_eq!(id1, id2);
assert_ne!(id1, id3);
assert_eq!(id1.id(), 42);
}
// JIT metadata is tagged `Jit`, named "jit" and marked as cached.
#[test]
fn test_jit_metadata() {
let jit_fn = JitTransform::new(|x: i32| x * 2);
let metadata = jit_fn.metadata();
assert_eq!(metadata.transform_type, TransformType::Jit);
assert_eq!(metadata.name, "jit");
assert!(metadata.is_cached);
}
// Vmap metadata is tagged `Vmap` and named "vmap".
#[test]
fn test_vmap_metadata() {
let vmap_fn = VmapTransform::new(|x: f32| x * x, 0);
let metadata = vmap_fn.metadata();
assert_eq!(metadata.transform_type, TransformType::Vmap);
assert_eq!(metadata.name, "vmap");
}
// Pmap metadata is tagged `Pmap` and named "pmap".
#[test]
fn test_pmap_metadata() {
let pmap_fn = PmapTransform::new(|x: i32| x * 2, 4);
let metadata = pmap_fn.metadata();
assert_eq!(metadata.transform_type, TransformType::Pmap);
assert_eq!(metadata.name, "pmap");
}
}