use crate::{Result, VisionError};
use std::path::Path;
use torsh_tensor::Tensor;
/// Configuration for the three.js-based 3D HTML visualizations.
#[derive(Debug, Clone, PartialEq)]
pub struct Visualizer3D {
    /// Scene clear colour as RGB components; the HTML generators scale each
    /// component by 255 when emitting the colour.
    pub background_color: (f32, f32, f32),
    /// Initial camera position in world coordinates.
    pub camera_position: (f32, f32, f32),
    /// Desired camera look-at point in world coordinates.
    pub camera_target: (f32, f32, f32),
}
/// A coloured point in 3D space.
#[derive(Debug, Clone, PartialEq)]
pub struct Point3D {
    /// X coordinate.
    pub x: f32,
    /// Y coordinate.
    pub y: f32,
    /// Z coordinate.
    pub z: f32,
    /// RGB colour, one byte per channel.
    pub color: (u8, u8, u8),
    /// Scalar weight attached to the point (e.g. `|activation|` in feature-map
    /// views, `1.0` for generated mesh vertices).
    pub intensity: f32,
}
/// A triangle mesh.
///
/// `faces` holds index triples into `vertices`. `generate_mesh_from_mask` fills
/// `normals` with one unit normal per entry of `faces` (per face, not per vertex).
#[derive(Debug, Clone)]
pub struct Mesh3D {
/// Mesh vertices, referenced by index from `faces`.
pub vertices: Vec<Point3D>,
/// Triangles, each given as three indices into `vertices`.
pub faces: Vec<[usize; 3]>,
/// Normals; see the type-level docs for how they relate to `faces`.
pub normals: Vec<(f32, f32, f32)>,
}
/// A single voxel selected for display in a volume view.
#[derive(Debug, Clone, PartialEq)]
pub struct VoxelData {
    /// X coordinate (`visualize_volume` stores the W index here).
    pub x: f32,
    /// Y coordinate (`visualize_volume` stores the H index here).
    pub y: f32,
    /// Z coordinate (`visualize_volume` stores the D index here).
    pub z: f32,
    /// Tensor value at this voxel.
    pub value: f32,
    /// RGB colour, one byte per channel.
    pub color: (u8, u8, u8),
}
impl Visualizer3D {
    /// Creates a visualizer with a dark grey background `(0.1, 0.1, 0.1)`, the
    /// camera at `(0, 0, 5)` and the camera target at the origin.
    pub fn new() -> Self {
        Self {
            background_color: (0.1, 0.1, 0.1),
            camera_position: (0.0, 0.0, 5.0),
            camera_target: (0.0, 0.0, 0.0),
        }
    }

    /// Builds a self-contained three.js HTML page showing every voxel of a
    /// `(D, H, W)` volume whose value is strictly greater than `threshold`.
    ///
    /// Each voxel is drawn as a unit cube, offset so the volume's centre sits at
    /// the origin and coloured with [`Self::value_to_color`].
    ///
    /// # Errors
    /// Returns `VisionError::InvalidShape` if `volume` is not 3-dimensional and
    /// propagates any error from reading tensor elements.
    pub fn visualize_volume(&self, volume: &Tensor<f32>, threshold: f32) -> Result<String> {
        let (depth, height, width) = Self::dims_3d(volume, "D, H, W")?;
        let mut voxels = Vec::new();
        for d in 0..depth {
            for h in 0..height {
                for w in 0..width {
                    let value = volume.get(&[d, h, w])?;
                    // `>` also drops NaN values.
                    if value > threshold {
                        voxels.push(VoxelData {
                            x: w as f32,
                            y: h as f32,
                            z: d as f32,
                            value,
                            color: self.value_to_color(value),
                        });
                    }
                }
            }
        }
        self.generate_volume_html(&voxels, width as f32, height as f32, depth as f32)
    }

    /// Builds a three.js point-cloud page for a `(C, H, W)` feature map.
    ///
    /// Coordinates are normalised as `x = w / W`, `y = h / H`, `z = c / C`.
    /// Channels are visited with a stride of `max(2, C / 10)` and spatial positions
    /// with a stride of `max(1, H * W / 1000)` on both axes. Only activations with
    /// `|value| > 0.1` become points.
    ///
    /// # Errors
    /// Returns `VisionError::InvalidShape` if `feature_map` is not 3-dimensional
    /// and propagates any error from reading tensor elements.
    pub fn visualize_feature_map(
        &self,
        feature_map: &Tensor<f32>,
        layer_name: &str,
    ) -> Result<String> {
        let (channels, height, width) = Self::dims_3d(feature_map, "C, H, W")?;
        // NOTE(review): the same stride is applied to both spatial axes, so the
        // number of sampled points per channel shrinks as H * W grows; confirm
        // this is the intended sampling density.
        let sample_rate = (height * width / 1000).max(1);
        let channel_step = (channels / 10).max(2);
        let mut points = Vec::new();
        for c in (0..channels).step_by(channel_step) {
            for h in (0..height).step_by(sample_rate) {
                for w in (0..width).step_by(sample_rate) {
                    let value = feature_map.get(&[c, h, w])?;
                    if value.abs() > 0.1 {
                        points.push(Point3D {
                            x: w as f32 / width as f32,
                            y: h as f32 / height as f32,
                            z: c as f32 / channels as f32,
                            color: self.value_to_color(value),
                            intensity: value.abs(),
                        });
                    }
                }
            }
        }
        self.generate_point_cloud_html(&points, layer_name)
    }

    /// Builds a coarse surface mesh from a `(D, H, W)` mask.
    ///
    /// For every 2x2x2 cell whose eight corners straddle `0.5`, one axis-aligned
    /// quad is emitted on the cell's `z = d` face: four new white vertices and
    /// two triangles. Vertices are not shared between cells, and this is not
    /// marching cubes. A mask with any axis shorter than 2 yields an empty mesh.
    ///
    /// # Errors
    /// Returns `VisionError::InvalidShape` if `mask` is not 3-dimensional and
    /// propagates any error from reading tensor elements.
    pub fn generate_mesh_from_mask(&self, mask: &Tensor<f32>) -> Result<Mesh3D> {
        const THRESHOLD: f32 = 0.5;
        let (depth, height, width) = Self::dims_3d(mask, "D, H, W")?;
        let mut vertices = Vec::new();
        let mut faces = Vec::new();
        // `saturating_sub` keeps zero-sized axes from underflowing `usize`
        // (a panic in debug builds, a wrap-around in release builds).
        for d in 0..depth.saturating_sub(1) {
            for h in 0..height.saturating_sub(1) {
                for w in 0..width.saturating_sub(1) {
                    let cube_values = [
                        mask.get(&[d, h, w])?,
                        mask.get(&[d, h, w + 1])?,
                        mask.get(&[d, h + 1, w + 1])?,
                        mask.get(&[d, h + 1, w])?,
                        mask.get(&[d + 1, h, w])?,
                        mask.get(&[d + 1, h, w + 1])?,
                        mask.get(&[d + 1, h + 1, w + 1])?,
                        mask.get(&[d + 1, h + 1, w])?,
                    ];
                    if !self.has_surface_intersection(&cube_values, THRESHOLD) {
                        continue;
                    }
                    let base = vertices.len();
                    let (x0, x1) = (w as f32, (w + 1) as f32);
                    let (y0, y1) = (h as f32, (h + 1) as f32);
                    let z = d as f32;
                    for &(x, y) in &[(x0, y0), (x1, y0), (x1, y1), (x0, y1)] {
                        vertices.push(Point3D {
                            x,
                            y,
                            z,
                            color: (255, 255, 255),
                            intensity: 1.0,
                        });
                    }
                    faces.push([base, base + 1, base + 2]);
                    faces.push([base, base + 2, base + 3]);
                }
            }
        }
        let normals = self.calculate_normals(&vertices, &faces);
        Ok(Mesh3D {
            vertices,
            faces,
            normals,
        })
    }

    /// Returns `true` when the eight corner values straddle `threshold`: at least
    /// one corner is above it and at least one is not.
    fn has_surface_intersection(&self, values: &[f32; 8], threshold: f32) -> bool {
        let above = values.iter().filter(|&&v| v > threshold).count();
        above > 0 && above < values.len()
    }

    /// Computes one unit normal per face, using the cross product of the edges
    /// from the face's first vertex. Degenerate (zero-area) faces get `(0, 1, 0)`.
    ///
    /// # Panics
    /// Panics if a face index is out of bounds for `vertices`.
    fn calculate_normals(
        &self,
        vertices: &[Point3D],
        faces: &[[usize; 3]],
    ) -> Vec<(f32, f32, f32)> {
        let mut normals = Vec::with_capacity(faces.len());
        for face in faces {
            let v1 = &vertices[face[0]];
            let v2 = &vertices[face[1]];
            let v3 = &vertices[face[2]];
            let edge1 = (v2.x - v1.x, v2.y - v1.y, v2.z - v1.z);
            let edge2 = (v3.x - v1.x, v3.y - v1.y, v3.z - v1.z);
            let normal = (
                edge1.1 * edge2.2 - edge1.2 * edge2.1,
                edge1.2 * edge2.0 - edge1.0 * edge2.2,
                edge1.0 * edge2.1 - edge1.1 * edge2.0,
            );
            let length = (normal.0 * normal.0 + normal.1 * normal.1 + normal.2 * normal.2).sqrt();
            if length > 0.0 {
                normals.push((normal.0 / length, normal.1 / length, normal.2 / length));
            } else {
                normals.push((0.0, 1.0, 0.0));
            }
        }
        normals
    }

    /// Maps `|value|`, clamped to `[0, 1]`, onto a colour. Non-negative values
    /// blend from blue to red; negative values blend from red to green.
    fn value_to_color(&self, value: f32) -> (u8, u8, u8) {
        let normalized = (value.abs().clamp(0.0, 1.0) * 255.0) as u8;
        if value >= 0.0 {
            (normalized, 0, 255 - normalized)
        } else {
            (255 - normalized, normalized, 0)
        }
    }

    /// Checks that `tensor` is 3-dimensional and returns its three extents.
    /// `layout` names the axes in the error message (e.g. `"D, H, W"`).
    fn dims_3d(tensor: &Tensor<f32>, layout: &str) -> Result<(usize, usize, usize)> {
        let shape = tensor.shape();
        let dims = shape.dims();
        if dims.len() != 3 {
            return Err(VisionError::InvalidShape(format!(
                "Expected 3D tensor ({}), got {}D",
                layout,
                dims.len()
            )));
        }
        Ok((dims[0], dims[1], dims[2]))
    }

    /// Background colour as 8-bit RGB. Each component is scaled by 255; the
    /// float-to-int cast saturates out-of-range values.
    fn background_rgb(&self) -> (u8, u8, u8) {
        let (r, g, b) = self.background_color;
        ((r * 255.0) as u8, (g * 255.0) as u8, (b * 255.0) as u8)
    }

    /// Formats a float as a JavaScript literal. Non-finite values become `null`,
    /// because Rust would print `inf`/`NaN`, and `inf` is not valid JS.
    fn js_number(value: f32) -> String {
        if value.is_finite() {
            value.to_string()
        } else {
            "null".to_string()
        }
    }

    /// Escapes characters that are significant in HTML text and attribute values.
    fn escape_html(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#39;"),
                _ => escaped.push(ch),
            }
        }
        escaped
    }

    /// Renders `voxels` into a three.js page.
    ///
    /// `width`/`height`/`depth` are shown in the info panel and used to centre the
    /// voxels on the origin. The camera starts at `camera_position`, and the orbit
    /// controls pivot around `camera_target`.
    fn generate_volume_html(
        &self,
        voxels: &[VoxelData],
        width: f32,
        height: f32,
        depth: f32,
    ) -> Result<String> {
        // Embed the voxels so the page actually renders the data it was given.
        let voxels_json = voxels
            .iter()
            .map(|v| {
                format!(
                    r#"{{"x": {}, "y": {}, "z": {}, "color": [{}, {}, {}]}}"#,
                    Self::js_number(v.x),
                    Self::js_number(v.y),
                    Self::js_number(v.z),
                    v.color.0,
                    v.color.1,
                    v.color.2
                )
            })
            .collect::<Vec<_>>()
            .join(", ");
        let (bg_r, bg_g, bg_b) = self.background_rgb();
        let html = format!(
            r#"
<!DOCTYPE html>
<html>
<head>
<title>ToRSh Vision 3D Volume Viewer</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/controls/OrbitControls.js"></script>
<style>
body {{ margin: 0; padding: 0; background: #000; }}
#container {{ width: 100vw; height: 100vh; }}
#info {{ position: absolute; top: 10px; left: 10px; color: white; font-family: Arial; }}
</style>
</head>
<body>
<div id="container"></div>
<div id="info">
<h3>3D Volume Visualization</h3>
<p>Dimensions: {width:.0} x {height:.0} x {depth:.0}</p>
<p>Voxels: {voxel_count}</p>
<p>Mouse: Rotate | Scroll: Zoom</p>
</div>
<script>
// Scene setup
const scene = new THREE.Scene();
const camera = new THREE.PerspectiveCamera(75, window.innerWidth / window.innerHeight, 0.1, 1000);
const renderer = new THREE.WebGLRenderer();
renderer.setSize(window.innerWidth, window.innerHeight);
renderer.setClearColor(0x{bg_r:02x}{bg_g:02x}{bg_b:02x});
document.getElementById('container').appendChild(renderer.domElement);
// Controls
const controls = new THREE.OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
controls.target.set({target_x:.1}, {target_y:.1}, {target_z:.1});
// Add voxels
const voxelGroup = new THREE.Group();
const voxelGeometry = new THREE.BoxGeometry(1, 1, 1);
const voxelData = [{voxels_json}];
voxelData.forEach(voxel => {{
const material = new THREE.MeshBasicMaterial({{
color: new THREE.Color(`rgb(${{voxel.color[0]}}, ${{voxel.color[1]}}, ${{voxel.color[2]}})`)
}});
const cube = new THREE.Mesh(voxelGeometry, material);
cube.position.set(voxel.x - {half_w:.1}, voxel.y - {half_h:.1}, voxel.z - {half_d:.1});
voxelGroup.add(cube);
}});
scene.add(voxelGroup);
// Lighting
const ambientLight = new THREE.AmbientLight(0x404040, 0.6);
scene.add(ambientLight);
const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
directionalLight.position.set(10, 10, 5);
scene.add(directionalLight);
// Camera position
camera.position.set({cam_x:.1}, {cam_y:.1}, {cam_z:.1});
// Animation loop
function animate() {{
requestAnimationFrame(animate);
controls.update();
renderer.render(scene, camera);
}}
// Handle window resize
window.addEventListener('resize', () => {{
camera.aspect = window.innerWidth / window.innerHeight;
camera.updateProjectionMatrix();
renderer.setSize(window.innerWidth, window.innerHeight);
}});
animate();
</script>
</body>
</html>
"#,
            width = width,
            height = height,
            depth = depth,
            voxel_count = voxels.len(),
            bg_r = bg_r,
            bg_g = bg_g,
            bg_b = bg_b,
            target_x = self.camera_target.0,
            target_y = self.camera_target.1,
            target_z = self.camera_target.2,
            voxels_json = voxels_json,
            half_w = width / 2.0,
            half_h = height / 2.0,
            half_d = depth / 2.0,
            cam_x = self.camera_position.0,
            cam_y = self.camera_position.1,
            cam_z = self.camera_position.2
        );
        Ok(html)
    }

    /// Renders `points` as a slowly rotating three.js point cloud titled with the
    /// HTML-escaped `layer_name`.
    fn generate_point_cloud_html(&self, points: &[Point3D], layer_name: &str) -> Result<String> {
        let points_json = points
            .iter()
            .map(|p| {
                format!(
                    r#"{{"x": {}, "y": {}, "z": {}, "color": [{}, {}, {}], "intensity": {}}}"#,
                    Self::js_number(p.x),
                    Self::js_number(p.y),
                    Self::js_number(p.z),
                    p.color.0,
                    p.color.1,
                    p.color.2,
                    Self::js_number(p.intensity)
                )
            })
            .collect::<Vec<_>>()
            .join(", ");
        // The layer name is caller-supplied text; escape it before splicing it into markup.
        let title = Self::escape_html(layer_name);
        let (bg_r, bg_g, bg_b) = self.background_rgb();
        let html = format!(
            r#"
<!DOCTYPE html>
<html>
<head>
<title>ToRSh Vision 3D Feature Map: {title}</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/controls/OrbitControls.js"></script>
<style>
body {{ margin: 0; padding: 0; background: #000; }}
#container {{ width: 100vw; height: 100vh; }}
#info {{ position: absolute; top: 10px; left: 10px; color: white; font-family: Arial; }}
</style>
</head>
<body>
<div id="container"></div>
<div id="info">
<h3>3D Feature Map: {title}</h3>
<p>Points: {point_count}</p>
<p>Mouse: Rotate | Scroll: Zoom</p>
</div>
<script>
// Scene setup
const scene = new THREE.Scene();
const camera = new THREE.PerspectiveCamera(75, window.innerWidth / window.innerHeight, 0.1, 1000);
const renderer = new THREE.WebGLRenderer();
renderer.setSize(window.innerWidth, window.innerHeight);
renderer.setClearColor(0x{bg_r:02x}{bg_g:02x}{bg_b:02x});
document.getElementById('container').appendChild(renderer.domElement);
// Controls
const controls = new THREE.OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
// Point cloud
const pointsData = [{points_json}];
const geometry = new THREE.BufferGeometry();
const positions = [];
const colors = [];
pointsData.forEach(point => {{
positions.push(point.x, point.y, point.z);
colors.push(point.color[0] / 255, point.color[1] / 255, point.color[2] / 255);
}});
geometry.setAttribute('position', new THREE.Float32BufferAttribute(positions, 3));
geometry.setAttribute('color', new THREE.Float32BufferAttribute(colors, 3));
const material = new THREE.PointsMaterial({{
size: 0.01,
vertexColors: true,
transparent: true,
opacity: 0.8
}});
const pointCloud = new THREE.Points(geometry, material);
scene.add(pointCloud);
// Lighting
const ambientLight = new THREE.AmbientLight(0x404040, 0.6);
scene.add(ambientLight);
// Camera position
camera.position.set(1.5, 1.5, 1.5);
// Animation loop
function animate() {{
requestAnimationFrame(animate);
controls.update();
pointCloud.rotation.y += 0.005;
renderer.render(scene, camera);
}}
// Handle window resize
window.addEventListener('resize', () => {{
camera.aspect = window.innerWidth / window.innerHeight;
camera.updateProjectionMatrix();
renderer.setSize(window.innerWidth, window.innerHeight);
}});
animate();
</script>
</body>
</html>
"#,
            title = title,
            point_count = points.len(),
            bg_r = bg_r,
            bg_g = bg_g,
            bg_b = bg_b,
            points_json = points_json
        );
        Ok(html)
    }

    /// Writes `html_content` to `path`, replacing any existing file.
    ///
    /// # Errors
    /// Propagates the underlying I/O error.
    pub fn save_visualization<P: AsRef<Path>>(&self, html_content: &str, path: P) -> Result<()> {
        std::fs::write(path, html_content)?;
        Ok(())
    }
}
impl Default for Visualizer3D {
fn default() -> Self {
Self::new()
}
}
/// Convenience constructor returning a [`Visualizer3D`] with its default settings.
pub fn create_3d_visualizer() -> Visualizer3D {
    Visualizer3D::default()
}
pub fn visualize_activations_3d(
activations: &[Tensor<f32>],
layer_names: &[String],
output_dir: &str,
) -> Result<()> {
let visualizer = Visualizer3D::new();
std::fs::create_dir_all(output_dir)?;
for (i, (activation, name)) in activations.iter().zip(layer_names.iter()).enumerate() {
let html = visualizer.visualize_feature_map(activation, name)?;
let filename = format!("{}/layer_{:02}_{}.html", output_dir, i, name);
visualizer.save_visualization(&html, filename)?;
}
let index_html = create_activation_index(layer_names);
let index_path = format!("{}/index.html", output_dir);
std::fs::write(index_path, index_html)?;
println!("3D visualizations saved to: {}", output_dir);
Ok(())
}
fn create_activation_index(layer_names: &[String]) -> String {
let mut links = String::new();
for (i, name) in layer_names.iter().enumerate() {
links.push_str(&format!(
r#"<li><a href="layer_{:02}_{}.html">{}</a></li>"#,
i, name, name
));
}
format!(
r#"
<!DOCTYPE html>
<html>
<head>
<title>ToRSh Vision 3D Activations Index</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 40px; }}
ul {{ list-style-type: none; padding: 0; }}
li {{ margin: 10px 0; }}
a {{ display: block; padding: 10px; background: #f0f0f0; text-decoration: none; border-radius: 4px; }}
a:hover {{ background: #e0e0e0; }}
</style>
</head>
<body>
<h1>ToRSh Vision 3D Layer Activations</h1>
<p>Click on any layer to view its 3D activation visualization:</p>
<ul>
{}
</ul>
</body>
</html>
"#,
links
)
}