smpl-utils 0.9.0

Small library for utility functions related to smpl-rs
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
use burn::tensor::{backend::Backend, Float, Int, Tensor};
use gloss_utils::nshare::{RefNdarray2, ToNalgebra};
use nalgebra as na;
use nalgebra::clamp;
use ndarray as nd;
use ndarray::prelude::*;
use std::{
    f32::consts::PI,
    ops::{Div, SubAssign},
};
/// Parses a hex colour string such as `"#ff8000"` into its `(r, g, b)` bytes.
///
/// Leading `#` characters are stripped. Each channel is read from two
/// hexadecimal digits (bytes `0..2`, `2..4` and `4..6` of the remainder).
/// A channel that is missing because the string is too short, that does not
/// fall on a character boundary, or that is not valid hexadecimal falls back
/// to `0` instead of panicking.
pub fn hex_to_rgb(hex: &str) -> (u8, u8, u8) {
    let digits = hex.trim_start_matches('#');
    // `str::get` yields `None` for out-of-range or non-boundary ranges, so
    // malformed input degrades to 0 just like an unparsable digit pair does.
    let channel = |range: std::ops::Range<usize>| {
        digits
            .get(range)
            .and_then(|pair| u8::from_str_radix(pair, 16).ok())
            .unwrap_or(0)
    };
    (channel(0..2), channel(2..4), channel(4..6))
}
/// Converts a hex colour string into `(r, g, b)` floats.
///
/// The bytes returned by [`hex_to_rgb`] are each divided by `255.0`, so every
/// channel lies in `[0.0, 1.0]`.
pub fn hex_to_rgb_f32(hex: &str) -> (f32, f32, f32) {
    let normalise = |channel: u8| f32::from(channel) / 255.0;
    let (red, green, blue) = hex_to_rgb(hex);
    (normalise(red), normalise(green), normalise(blue))
}
/// Moves `cur_angle` towards `other_angle` by the fraction `other_w`.
///
/// When the raw difference between the two angles exceeds `PI` in magnitude
/// it is shifted once by a full turn (`2 * PI`) towards zero before the blend.
/// `_cur_w` is accepted but not used.
pub fn interpolate_angle(cur_angle: f32, other_angle: f32, _cur_w: f32, other_w: f32) -> f32 {
    let raw_delta = other_angle - cur_angle;
    let wrapped_delta = if raw_delta.abs() > PI {
        // Shift by one full turn in the direction that shrinks the difference.
        if raw_delta > 0.0 {
            raw_delta - 2.0 * PI
        } else {
            raw_delta + 2.0 * PI
        }
    } else {
        raw_delta
    };
    cur_angle + other_w * wrapped_delta
}
/// Tensor counterpart of [`interpolate_angle`] for a single-element tensor.
///
/// The difference `other_angle - cur_angle` is shifted by `-2 * PI` where it
/// is positive and by `+2 * PI` where it is negative, but only for elements
/// whose magnitude exceeds `PI`. The result is `cur_angle + other_w * diff`.
/// `_cur_w` is accepted but not used.
///
/// # Panics
/// Panics if `cur_angle` does not have shape `[1]`.
pub fn interpolate_angle_tensor<B: Backend>(cur_angle: Tensor<B, 1>, other_angle: Tensor<B, 1>, _cur_w: f32, other_w: f32) -> Tensor<B, 1> {
    let delta = other_angle - cur_angle.clone();
    assert!(cur_angle.dims() == [1]);
    let device = cur_angle.device();

    // 1.0 where the raw difference is larger than half a turn, else 0.0.
    let exceeds_half_turn = delta.clone().abs().greater_elem(PI).float();

    let plus_full_turn = Tensor::<B, 1>::from_floats([2.0 * PI], &device);
    let minus_full_turn = Tensor::<B, 1>::from_floats([-2.0 * PI], &device);

    // Positive differences are wrapped downwards, negative ones upwards.
    let wrap_down = delta.clone().greater_elem(0.0).float() * minus_full_turn;
    let wrap_up = delta.clone().lower_elem(0.0).float() * plus_full_turn;
    let correction = exceeds_half_turn * (wrap_down + wrap_up);

    cur_angle + other_w * (delta + correction)
}
/// Converts rows of axis-angle vectors into quaternions laid out as
/// `[x, y, z, w]`.
///
/// Per row, the angle is the Euclidean norm of the row (the squared norm is
/// clamped to at least `1e-8` before the square root) and the axis is the row
/// divided by `angle + 1e-6`. The output row is `axis * sin(angle / 2)`
/// followed by `cos(angle / 2)`.
pub fn axis_angle_to_quaternion<B: Backend>(axis_angle: Tensor<B, 2>) -> Tensor<B, 2> {
    let eps = 1e-6f32;
    let nr_rows = axis_angle.dims()[0];

    let angle: Tensor<B, 1> = axis_angle.clone().powf_scalar(2.0).sum_dim(1).squeeze_dims(&[1]).clamp_min(1e-8).sqrt();
    let denom: Tensor<B, 2> = angle.clone().unsqueeze_dim(1) + eps;
    let unit_axis = axis_angle / denom;

    let half_angle = angle * 0.5;
    let qw = half_angle.clone().cos().unsqueeze_dim(1);
    let sin_half = half_angle.sin();
    let qxyz = unit_axis.slice([0..nr_rows, 0..3]) * sin_half.unsqueeze_dim(1);

    Tensor::cat(vec![qxyz, qw], 1)
}
/// Converts rows of `[x, y, z, w]` quaternions into axis-angle vectors using
/// closed-form approximations instead of `atan`/`acos`.
///
/// Per row, with `n = |xyz|` and `|w|`:
/// - if `|w| > 0.9` the angle is taken as `2 * n`;
/// - otherwise the angle is `2 * x / (1 + 0.28 * x^2)` with
///   `x = n / (|w| + 1e-6)`.
///
/// The output is `xyz * angle / (n + 1e-6)`, forced to zero for rows where
/// `n < 1e-6`.
///
/// NOTE(review): only `|w|` is used, so the sign of `w` does not influence the
/// angle; the rational `atan` approximation is presumably intended for
/// moderate `x` only. Confirm the accuracy is acceptable for callers.
pub fn quaternion_to_axis_angle<B: Backend>(quat: Tensor<B, 2>) -> Tensor<B, 2> {
    let eps = 1e-6f32;
    let nr_rows = quat.dims()[0];

    let xyz = quat.clone().slice([0..nr_rows, 0..3]);
    let w: Tensor<B, 1> = quat.slice([0..nr_rows, 3..4]).squeeze(1);
    let w_abs = w.abs();
    let xyz_norm: Tensor<B, 1> = xyz.clone().powf_scalar(2.0).sum_dim(1).clamp_min(1e-8).sqrt().squeeze_dims(&[1]);

    // Angle estimate for rotations close to identity: sin(a/2) ~ a/2.
    let sin_based_angle = 2.0 * xyz_norm.clone();

    // Angle estimate for everything else: rational approximation of atan.
    let tan_half = xyz_norm.clone() / (w_abs.clone() + eps);
    let atan_approx = tan_half.clone() / (1.0 + 0.28 * tan_half.powf_scalar(2.0));
    let atan_based_angle = 2.0 * atan_approx;

    // Blend the two estimates with a 0/1 mask.
    let near_identity = w_abs.greater_elem(0.9).float();
    let angle: Tensor<B, 1> = near_identity.clone() * sin_based_angle + (1.0 - near_identity) * atan_based_angle;

    // Rows with a vanishing vector part are zeroed at the very end.
    let degenerate: Tensor<B, 2> = xyz_norm.clone().lower_elem(eps).float().unsqueeze_dim(1);

    let guarded_norm = xyz_norm + eps;
    let scale = angle.unsqueeze_dim(1) / guarded_norm.unsqueeze_dim(1);
    let axis_angle = xyz * scale;

    (1.0 - degenerate) * axis_angle
}
/// Converts rows of `[x, y, z, w]` quaternions into axis-angle vectors using a
/// polynomial approximation of `acos`.
///
/// Rows with negative `w` are negated first. `w` is then clamped to `[0, 1]`,
/// `acos(w)` is approximated as `sqrt(1 - w)` times a cubic polynomial in `w`,
/// and the angle is twice that value. The axis is `xyz` divided by
/// `sqrt(1 - w^2) + 1e-6`, and the result is `axis * angle`.
pub fn quaternion_to_axis_angle_fast<B: Backend>(quat: Tensor<B, 2>) -> Tensor<B, 2> {
    let eps = 1e-6f32;
    let nr_rows = quat.dims()[0];
    let raw_xyz = quat.clone().slice([0..nr_rows, 0..3]);
    let raw_w: Tensor<B, 1> = quat.slice([0..nr_rows, 3..4]).squeeze(1);

    // 0/1 masks selecting the rows that have to be negated so that w >= 0.
    let flip: Tensor<B, 1> = raw_w.clone().lower_elem(0.0).float();
    let keep: Tensor<B, 1> = 1.0 - flip.clone();
    let flip_col: Tensor<B, 2> = flip.clone().unsqueeze_dim(1);
    let keep_col: Tensor<B, 2> = 1.0 - flip_col.clone();

    let w: Tensor<B, 1> = flip * (-raw_w.clone()) + keep * raw_w;
    let xyz = flip_col * (-raw_xyz.clone()) + keep_col * raw_xyz;

    let w_unit = w.clamp(0.0, 1.0);

    // acos(w) ~ sqrt(1 - w) * cubic(w)
    let one_minus_w: Tensor<B, 1> = 1.0 - w_unit.clone();
    let acos_w = one_minus_w.sqrt()
        * (1.570_728_8 + w_unit.clone() * (-0.212_114_4 + w_unit.clone() * (0.074_261_0 + w_unit.clone() * -0.018_729_3)));
    let angle: Tensor<B, 1> = 2.0 * acos_w;

    // sin(angle / 2) recovered from the clamped w.
    let one_minus_w_sq: Tensor<B, 1> = 1.0 - w_unit.clone() * w_unit;
    let sin_half: Tensor<B, 1> = one_minus_w_sq.sqrt();
    let guarded_sin_half = sin_half + eps;

    let unit_axis = xyz / guarded_sin_half.unsqueeze_dim(1);
    unit_axis * angle.unsqueeze_dim(1)
}
/// Row-wise spherical interpolation between two batches of quaternions,
/// implemented with masks and an approximate angle.
///
/// - Both inputs are divided by their row norm (squared norm clamped to at
///   least `1e-8`, plus `1e-6`).
/// - Rows of `other` whose dot product with `lhs` is negative are negated.
/// - Rows whose absolute dot product exceeds `0.9995` use a normalised linear
///   blend; all other rows use slerp weights `sin((1 - t) * theta) / sin(theta)`
///   and `sin(t * theta) / sin(theta)` with `t = other_weight`.
/// - `theta` is not computed with `acos`/`atan`; it is approximated as
///   `r / (1 + 0.28 * r^2)` with `r = sqrt(1 - dot^2) / (|dot| + 1e-6)`.
///
/// NOTE(review): the slerp branch is not re-normalised and relies on the
/// approximated `theta`; the rational approximation is presumably accurate
/// only for moderate `r`. Confirm this is acceptable for widely separated
/// rotations.
pub fn quaternion_interpolate_slerp<B: Backend>(lhs: Tensor<B, 2>, other: Tensor<B, 2>, other_weight: f32) -> Tensor<B, 2> {
    let eps = 1e-6f32;
    // Normalise both operands row by row.
    let lhs_norm = lhs.clone().powf_scalar(2.0).sum_dim(1).clamp_min(1e-8).sqrt() + eps;
    let other_norm = other.clone().powf_scalar(2.0).sum_dim(1).clamp_min(1e-8).sqrt() + eps;
    let lhs_normalized = lhs / lhs_norm;
    let other_normalized = other / other_norm;
    // Per-row dot product; a negative value means `other` is negated below.
    let dot: Tensor<B, 1> = (lhs_normalized.clone() * other_normalized.clone()).sum_dim(1).squeeze_dims(&[1]);
    let negative_dot_mask = dot.clone().lower_elem(0.0);
    let negative_dot_mask_float: Tensor<B, 1> = negative_dot_mask.clone().float();
    let dot_mask_float: Tensor<B, 1> = 1.0 - negative_dot_mask_float;
    let sign_corrected_other =
        negative_dot_mask.clone().float().unsqueeze_dim(1) * (-other_normalized.clone()) + dot_mask_float.unsqueeze_dim(1) * other_normalized.clone();
    let corrected_dot: Tensor<B, 1> = dot.clone().abs();
    // Rows above this threshold fall back to the normalised linear blend.
    let close_threshold = 0.9995f32;
    let very_close_mask = corrected_dot.clone().greater_elem(close_threshold);
    let lerp_result = lhs_normalized.clone() * (1.0 - other_weight) + sign_corrected_other.clone() * other_weight;
    let lerp_norm = lerp_result.clone().powf_scalar(2.0).sum_dim(1).clamp_min(1e-8).sqrt() + eps;
    let lerp_normalized = lerp_result / lerp_norm;
    // ratio = sqrt(1 - dot^2) / dot, fed into a rational approximation to get theta.
    let one_minus_dot_sq: Tensor<B, 1> = 1.0 - corrected_dot.clone().powf_scalar(2.0);
    let sqrt_term = one_minus_dot_sq.clamp_min(1e-8).sqrt();
    let safe_dot = corrected_dot.clone() + eps;
    let ratio = sqrt_term / safe_dot;
    let theta_approx: Tensor<B, 1> = ratio.clone() / (1.0 + 0.28 * ratio.clone().powf_scalar(2.0));
    let sin_theta = theta_approx.clone().sin();
    let safe_sin_theta = sin_theta.clone() + eps;
    // Slerp weights computed from the approximated angle.
    let weight_lhs = ((1.0 - other_weight) * theta_approx.clone()).sin() / safe_sin_theta.clone();
    let weight_other = (other_weight * theta_approx).sin() / safe_sin_theta;
    let slerp_result = lhs_normalized.clone() * weight_lhs.unsqueeze_dim(1) + sign_corrected_other * weight_other.unsqueeze_dim(1);
    // Select per row between the lerp and slerp results with a 0/1 mask.
    let inv_very_close_mask_float: Tensor<B, 1> = 1.0 - very_close_mask.clone().float();
    very_close_mask.clone().float().unsqueeze_dim(1) * lerp_normalized + inv_very_close_mask_float.unsqueeze_dim(1) * slerp_result
}
/// Row-wise normalised linear interpolation between two batches of
/// quaternions.
///
/// Rows of `other` whose dot product with the matching row of `lhs` is
/// negative are negated first. The blend
/// `lhs * (1 - other_weight) + other * other_weight` is then divided by its
/// row norm (squared norm clamped to at least `1e-8`, plus `1e-6`). The inputs
/// themselves are not normalised beforehand.
pub fn quaternion_interpolate_lerp<B: Backend>(lhs: Tensor<B, 2>, other: Tensor<B, 2>, other_weight: f32) -> Tensor<B, 2> {
    let eps = 1e-6f32;

    let row_dot: Tensor<B, 1> = (lhs.clone() * other.clone()).sum_dim(1).squeeze_dims(&[1]);

    // 0/1 masks: `flip` marks rows of `other` that must be negated.
    let flip: Tensor<B, 1> = row_dot.lower_elem(0.0).float();
    let keep: Tensor<B, 1> = 1.0 - flip.clone();
    let flip_col: Tensor<B, 2> = flip.unsqueeze_dim(1);
    let keep_col: Tensor<B, 2> = keep.unsqueeze_dim(1);
    let aligned_other = flip_col * (-other.clone()) + keep_col * other;

    let blended = lhs * (1.0 - other_weight) + aligned_other * other_weight;
    let blended_norm = blended.clone().powf_scalar(2.0).sum_dim(1).clamp_min(1e-8).sqrt() + eps;
    blended / blended_norm
}
/// Linearly remaps `value` from the range `[in_min, in_max]` to
/// `[out_min, out_max]`.
///
/// `value` is clamped to the input range first. When `in_min == in_max` the
/// final division is by zero and the result is not finite.
pub fn map(value: f32, in_min: f32, in_max: f32, out_min: f32, out_max: f32) -> f32 {
    let in_span = in_max - in_min;
    let out_span = out_max - out_min;
    let offset = clamp(value, in_min, in_max) - in_min;
    out_min + out_span * offset / in_span
}
/// Quintic smooth interpolation `6t^5 - 15t^4 + 10t^3` of `val` between `low`
/// and `high`.
///
/// `val` is first remapped (and clamped) to `t` in `[0, 1]` via [`map`].
pub fn smootherstep(low: f32, high: f32, val: f32) -> f32 {
    let t = map(val, low, high, 0.0, 1.0);
    let t_cubed = t * t * t;
    let polynomial = t * (t * 6.0 - 15.0) + 10.0;
    t_cubed * polynomial
}
/// Converts every row of `full_pose` (an axis-angle vector) into a 3x3
/// rotation matrix using Rodrigues' formula
/// `R = I + sin(a) * K + (1 - cos(a)) * K^2`.
///
/// `a` is the Euclidean norm of the row and `K` is the skew-symmetric matrix
/// of the row divided by `a + 1e-6`. The result has one 3x3 matrix per input
/// row.
pub fn batch_rodrigues(full_pose: &nd::Array2<f32>) -> nd::Array3<f32> {
    let nr_rows = full_pose.shape()[0];
    let mut rot_mats = nd::Array3::<f32>::zeros((nr_rows, 3, 3));

    let pose_rows = full_pose.axis_iter(nd::Axis(0));
    let out_mats = rot_mats.axis_iter_mut(nd::Axis(0));
    for (axis_angle, mut out_mat) in pose_rows.zip(out_mats) {
        let theta = axis_angle.iter().map(|c| c * c).sum::<f32>().sqrt();
        // The small offset keeps the division finite for a zero rotation.
        let unit_axis = axis_angle.to_owned().div(theta + 1e-6);
        let (ux, uy, uz) = (unit_axis[0], unit_axis[1], unit_axis[2]);

        let skew = array![[0.0, -uz, uy], [uz, 0.0, -ux], [-uy, ux, 0.0]];
        let skew_sq = skew.dot(&skew);
        let rot = nd::Array2::<f32>::eye(3) + theta.sin() * skew + (1.0 - theta.cos()) * skew_sq;
        out_mat.assign(&rot);
    }
    rot_mats
}
/// Burn version of Rodrigues' formula: turns each row of `full_pose` (an
/// axis-angle vector) into a rotation matrix, returning shape `[N, 3, 3]`.
///
/// The skew-symmetric matrix `K` is assembled explicitly and the result is
/// `I + sin(a) * K + (1 - cos(a)) * K.matmul(K)`, where `a` is the row norm
/// (squared norm clamped to at least `1e-8`) and the axis is the row divided
/// by `a + 1e-6`.
pub fn batch_rodrigues_burn<B: Backend>(full_pose: &Tensor<B, 2, Float>) -> Tensor<B, 3, Float> {
    let device = full_pose.device();
    let eps = Tensor::<B, 1, Float>::from_floats([1e-6], &device);

    // Row norms, kept as a column so they broadcast against the pose.
    let theta = full_pose.clone().powf_scalar(2.0).sum_dim(1).clamp_min(1e-8).sqrt();
    let denom: Tensor<B, 2> = theta.clone() + eps.unsqueeze_dim(0);
    let unit_axis = full_pose.clone() / denom;

    let ux: Tensor<B, 1> = unit_axis.clone().slice_dim(1, 0..1).squeeze(1);
    let uy: Tensor<B, 1> = unit_axis.clone().slice_dim(1, 1..2).squeeze(1);
    let uz: Tensor<B, 1> = unit_axis.clone().slice_dim(1, 2..3).squeeze(1);

    // Helper turning a per-row scalar into a single-column matrix.
    let as_col = |component: Tensor<B, 1>| -> Tensor<B, 2> { component.unsqueeze_dim(1) };
    let zero: Tensor<B, 2> = Tensor::<B, 1, Float>::zeros_like(&ux).unsqueeze_dim(1);

    // Rows of the skew-symmetric matrix, each of shape [N, 3].
    let skew_row1 = Tensor::cat(vec![zero.clone(), -as_col(uz.clone()), as_col(uy.clone())], 1);
    let skew_row2 = Tensor::cat(vec![as_col(uz), zero.clone(), -as_col(ux.clone())], 1);
    let skew_row3 = Tensor::cat(vec![-as_col(uy), as_col(ux), zero], 1);
    let skew: Tensor<B, 3> = Tensor::cat(vec![skew_row1, skew_row2, skew_row3], 1).reshape([-1, 3, 3]);

    let cos: Tensor<B, 3> = theta.clone().cos().unsqueeze_dim(2);
    let sin: Tensor<B, 3> = theta.sin().unsqueeze_dim(2);

    let eye: Tensor<B, 3> = Tensor::<B, 2, Float>::eye(3, &device).unsqueeze_dim(0);
    let eye = eye.repeat(&[full_pose.dims()[0], 1, 1]);

    let skew_sq = skew.clone().matmul(skew.clone());
    eye + sin * skew + (Tensor::ones_like(&cos) - cos) * skew_sq
}
/// Burn version of Rodrigues' formula that writes out the nine rotation
/// matrix entries element-wise instead of using a matrix product.
///
/// Each row of `full_pose` is an axis-angle vector; the result has shape
/// `[N, 3, 3]`. The angle is the row norm (squared norm clamped to at least
/// `1e-8`) and the axis is the row divided by `angle + 1e-6`.
pub fn batch_rodrigues_burn_2<B: Backend>(full_pose: &Tensor<B, 2, Float>) -> Tensor<B, 3, Float> {
    let eps = Tensor::<B, 1, Float>::from_floats([1e-6], &full_pose.device());
    let theta: Tensor<B, 1> = full_pose.clone().powf_scalar(2.0).sum_dim(1).clamp_min(1e-8).sqrt().squeeze(1);
    let denom: Tensor<B, 2> = theta.clone().unsqueeze_dim(1) + eps.unsqueeze_dim(0);
    let unit_axis = full_pose.clone() / denom;

    let ux: Tensor<B, 1> = unit_axis.clone().slice_dim(1, 0..1).squeeze(1);
    let uy: Tensor<B, 1> = unit_axis.clone().slice_dim(1, 1..2).squeeze(1);
    let uz: Tensor<B, 1> = unit_axis.clone().slice_dim(1, 2..3).squeeze(1);

    let cos: Tensor<B, 1> = theta.clone().cos();
    let sin: Tensor<B, 1> = theta.sin();
    let one = Tensor::<B, 1, Float>::ones_like(&cos);
    let one_minus_cos: Tensor<B, 1> = one - cos.clone();

    // Shared partial products; each entry below keeps the original
    // left-to-right multiplication order.
    let c_x: Tensor<B, 1> = one_minus_cos.clone() * ux.clone();
    let c_y: Tensor<B, 1> = one_minus_cos.clone() * uy.clone();
    let c_z: Tensor<B, 1> = one_minus_cos * uz.clone();
    let s_x: Tensor<B, 1> = sin.clone() * ux.clone();
    let s_y: Tensor<B, 1> = sin.clone() * uy.clone();
    let s_z: Tensor<B, 1> = sin * uz.clone();

    let row1 = Tensor::stack::<2>(
        vec![
            cos.clone() + c_x.clone() * ux.clone(),
            c_x.clone() * uy.clone() - s_z.clone(),
            c_x * uz.clone() + s_y.clone(),
        ],
        1,
    );
    let row2 = Tensor::stack::<2>(
        vec![
            c_y.clone() * ux.clone() + s_z,
            cos.clone() + c_y.clone() * uy.clone(),
            c_y * uz.clone() - s_x.clone(),
        ],
        1,
    );
    let row3 = Tensor::stack::<2>(vec![c_z.clone() * ux - s_y, c_z.clone() * uy + s_x, cos + c_z * uz], 1);

    Tensor::stack(vec![row1, row2, row3], 1)
}
/// Burn version of Rodrigues' formula in the form
/// `R = cos(a) * I + (1 - cos(a)) * k k^T + sin(a) * K`.
///
/// Each row of `full_pose` is an axis-angle vector; the result has shape
/// `[N, 3, 3]`. `a` is the row norm (squared norm clamped to at least `1e-8`),
/// `k` is the row divided by `a + 1e-6`, and `K` is the skew-symmetric matrix
/// of `k`.
pub fn batch_rodrigues_burn_3<B: Backend>(full_pose: &Tensor<B, 2, Float>) -> Tensor<B, 3, Float> {
    let device = full_pose.device();
    let theta: Tensor<B, 1> = full_pose.clone().powi_scalar(2).sum_dim(1).clamp_min(1e-8).sqrt().squeeze(1);
    let denom: Tensor<B, 2> = theta.clone().unsqueeze_dim(1) + 1e-6;
    let unit_axis = full_pose.clone() / denom;

    // Outer product k k^T via broadcasting [N, 3, 1] * [N, 1, 3].
    let axis_col: Tensor<B, 3> = unit_axis.clone().unsqueeze_dim(2);
    let axis_row: Tensor<B, 3> = unit_axis.clone().unsqueeze_dim(1);
    let outer = axis_col * axis_row;

    // Skew-symmetric matrix K, built row by row.
    let ux: Tensor<B, 1> = unit_axis.clone().slice_dim(1, 0..1).squeeze(1);
    let uy: Tensor<B, 1> = unit_axis.clone().slice_dim(1, 1..2).squeeze(1);
    let uz: Tensor<B, 1> = unit_axis.slice_dim(1, 2..3).squeeze(1);
    let zero = Tensor::<B, 1, Float>::zeros_like(&ux);
    let skew_row1 = Tensor::stack::<2>(vec![zero.clone(), -uz.clone(), uy.clone()], 1);
    let skew_row2 = Tensor::stack::<2>(vec![uz, zero.clone(), -ux.clone()], 1);
    let skew_row3 = Tensor::stack::<2>(vec![-uy, ux, zero], 1);
    let skew: Tensor<B, 3> = Tensor::stack(vec![skew_row1, skew_row2, skew_row3], 1);

    let cos: Tensor<B, 3> = theta.clone().cos().unsqueeze_dim::<2>(1).unsqueeze_dim(2);
    let sin: Tensor<B, 3> = theta.sin().unsqueeze_dim::<2>(1).unsqueeze_dim(2);
    let one_minus_cos: Tensor<B, 3> = 1.0 - cos.clone();
    let eye: Tensor<B, 3> = Tensor::<B, 2, Float>::eye(3, &device).unsqueeze_dim(0);

    cos * eye + one_minus_cos * outer + sin * skew
}
/// Converts three Euler angles into an axis-angle (scaled axis) vector.
///
/// A quaternion is assembled from the sines and cosines of the half angles,
/// normalised, and returned as its scaled axis.
///
/// NOTE(review): the rotation order implied by the quaternion formula is not
/// stated in the code; confirm the convention against callers.
pub fn euler2angleaxis(euler_x: f32, euler_y: f32, euler_z: f32) -> na::Vector3<f32> {
    let half_x = euler_x / 2.0;
    let half_y = euler_y / 2.0;
    let half_z = euler_z / 2.0;
    let (s1, c1) = (f32::sin(half_x), f32::cos(half_x));
    let (s2, c2) = (f32::sin(half_y), f32::cos(half_y));
    let (s3, c3) = (f32::sin(half_z), f32::cos(half_z));

    // Components in the order expected by `Quaternion::new`: w, i, j, k.
    let w = c1 * c2 * c3 - s1 * s2 * s3;
    let i = s1 * c2 * c3 + c1 * s2 * s3;
    let j = c1 * s2 * c3 - s1 * c2 * s3;
    let k = c1 * c2 * s3 + s1 * s2 * c3;

    na::UnitQuaternion::new_normalize(na::Quaternion::new(w, i, j, k)).scaled_axis()
}
/// Spherically interpolates between two axis-angle vectors.
///
/// Both inputs are converted to rotations, blended with `slerp` using
/// `other_weight` as the interpolation parameter, and the result is returned
/// as a three-element axis-angle vector.
pub fn interpolate_axis_angle(this_axis: &nd::Array1<f32>, other_axis: &nd::Array1<f32>, other_weight: f32) -> nd::Array1<f32> {
    let this_vec = this_axis.clone().into_nalgebra();
    let other_vec = other_axis.clone().into_nalgebra();

    let from_rot = na::Rotation3::new(this_vec.fixed_rows(0));
    let to_rot = na::Rotation3::new(other_vec.fixed_rows(0));

    let blended = from_rot.slerp(&to_rot, other_weight).scaled_axis();
    array![blended.x, blended.y, blended.z]
}
/// Spherically interpolates between two batches of axis-angle vectors of
/// shape [``nr_joints``, 3], row by row.
///
/// The output has the shape of `this_axis`. Rows are paired up in order; any
/// output row without a partner in `other_axis` is left at zero.
pub fn interpolate_axis_angle_batch(this_axis: &nd::Array2<f32>, other_axis: &nd::Array2<f32>, other_weight: f32) -> nd::Array2<f32> {
    let this_mat = this_axis.clone().into_nalgebra();
    let other_mat = other_axis.clone().into_nalgebra();
    let mut blended = nd::Array2::<f32>::zeros(this_axis.dim());

    let row_pairs = this_mat.row_iter().zip(other_mat.row_iter());
    for ((this_row, other_row), mut out_row) in row_pairs.zip(blended.axis_iter_mut(nd::Axis(0))) {
        // Rows are row-vectors; transpose them into the column form expected
        // by `Rotation3::new`.
        let from_rot = na::Rotation3::new(this_row.transpose().fixed_rows(0));
        let to_rot = na::Rotation3::new(other_row.transpose().fixed_rows(0));
        let scaled_axis = from_rot.slerp(&to_rot, other_weight).scaled_axis();
        out_row.assign(&array![scaled_axis.x, scaled_axis.y, scaled_axis.z]);
    }
    blended
}
/// Applies per-joint rotations along a kinematic chain.
///
/// - `parent_idx_per_joint`: parent index of every joint; entry 0 is not read.
/// - `rot_mats`: one 3x3 rotation per joint.
/// - `joints`: rest position of every joint, one row per joint.
/// - `num_joints`: index of the last joint processed, so `num_joints + 1`
///   transforms are built.
///
/// Returns `(posed_joints, rel_transforms)`: the joint positions after posing
/// (rows beyond `num_joints` keep their rest value) and, per joint, the 4x4
/// world transform with the transformed rest position subtracted from its
/// translation column.
///
/// Parents are looked up in the chain built so far, so every joint's parent
/// must have a smaller index than the joint itself.
#[allow(clippy::missing_panics_doc)]
#[allow(clippy::similar_names)]
#[allow(clippy::cast_sign_loss)]
pub fn batch_rigid_transform(
    parent_idx_per_joint: Vec<u32>,
    rot_mats: &nd::Array3<f32>,
    joints: &nd::Array2<f32>,
    num_joints: usize,
) -> (nd::Array2<f32>, nd::Array3<f32>) {
    let nr_transforms = num_joints + 1;
    let parents = nd::Array1::from_vec(parent_idx_per_joint);

    // Joint positions expressed relative to their parent (root stays absolute).
    let mut rel_joints = joints.clone();
    for (child, &parent) in parents.iter().enumerate().skip(1) {
        rel_joints.row_mut(child).sub_assign(&joints.row(parent as usize));
    }

    // Local 4x4 transform per joint: [R | t; 0 0 0 1].
    let mut local_transforms = nd::Array3::<f32>::zeros((nr_transforms, 4, 4));
    for (idx, mut local) in local_transforms.axis_iter_mut(nd::Axis(0)).enumerate() {
        local.slice_mut(s![0..3, 0..3]).assign(&rot_mats.slice(s![idx, .., ..]));
        local.slice_mut(s![0..3, 3]).assign(&rel_joints.row(idx));
        local.slice_mut(s![3, 0..4]).assign(&array![0.0, 0.0, 0.0, 1.0]);
    }

    // World transforms, accumulated from the root outwards.
    let mut chain: Vec<nd::Array2<f32>> = Vec::with_capacity(nr_transforms);
    chain.push(local_transforms.slice(s![0, .., ..]).to_owned());
    for joint in 1..=num_joints {
        let parent_world = &chain[parents[joint] as usize];
        let world = parent_world.dot(&local_transforms.slice(s![joint, .., ..]));
        chain.push(world);
    }

    let mut posed_joints = joints.clone();
    let mut rel_transforms = nd::Array3::<f32>::zeros((nr_transforms, 4, 4));
    for (idx, world) in chain.iter().enumerate() {
        // The posed joint position is the translation column of the world transform.
        posed_joints.row_mut(idx).assign(&world.slice(s![0..3, 3]));

        // Remove the transformed rest position from the translation column.
        let rest = joints.row(idx);
        let rest_homogen = array![rest[0], rest[1], rest[2], 0.0];
        let rest_in_world = world.dot(&rest_homogen);
        let mut correction = nd::Array2::<f32>::zeros((4, 4));
        correction.slice_mut(s![0..4, 3]).assign(&rest_in_world);
        let relative = world - correction;
        rel_transforms.slice_mut(s![idx, .., ..]).assign(&relative);
    }

    (posed_joints, rel_transforms)
}
/// Burn implementation of the batch rigid transform.
///
/// - `parent_idx_per_joint_t`: parent index per joint as an index tensor; its
///   first entry is overwritten with 0 before use.
/// - `parent_idx_per_joint`: the same parent indices on the CPU side, used to
///   walk the chain sequentially.
/// - `rot_mats`: per-joint rotations, concatenated with a translation column
///   and a `[0, 0, 0, 1]` row into 4x4 transforms.
/// - `joints`: rest joint positions, one row per joint.
///
/// Returns `(posed_joints, rel_transforms)`: the translation column of each
/// accumulated transform, and each accumulated transform with the transformed
/// rest joint subtracted from its last column.
pub fn batch_rigid_transform_burn<B: Backend>(
    parent_idx_per_joint_t: Tensor<B, 1, Int>,
    parent_idx_per_joint: &nd::Array1<u32>,
    rot_mats: Tensor<B, 3>,
    joints: Tensor<B, 2>,
) -> (Tensor<B, 2>, Tensor<B, 3>) {
    let num_joints = joints.dims()[0];
    let device = joints.device();

    // Joint offsets relative to the parent; the root keeps its own position.
    let parent_lookup = parent_idx_per_joint_t.slice_fill(0..1, 0);
    let parent_positions = joints.clone().select(0, parent_lookup);
    let offsets = joints.clone() - parent_positions;
    let offsets = offsets.slice_assign([0..1, 0..3], joints.clone().slice([0..1, 0..3]));

    // Local transforms [R | t; 0 0 0 1] for all joints at once.
    let translation_col = offsets.reshape([num_joints, 3, 1]);
    let top_rows = Tensor::cat(vec![rot_mats, translation_col], 2);
    let bottom_row = Tensor::zeros([num_joints, 1, 4], &device);
    let bottom_row = bottom_row.slice_fill([0..num_joints, 0..1, 3..4], 1.0);
    let local_transforms = Tensor::cat(vec![top_rows, bottom_row], 1);

    // Sequential accumulation: world[j] = world[parent[j]] * local[j].
    let mut world_per_joint: Vec<Tensor<B, 2>> = Vec::with_capacity(num_joints);
    #[allow(clippy::needless_range_loop)]
    #[allow(clippy::single_range_in_vec_init)]
    #[allow(clippy::range_plus_one)]
    for j in 0..num_joints {
        let parent = parent_idx_per_joint[j] as usize;
        let local: Tensor<B, 2> = local_transforms.clone().slice([j..j + 1]).squeeze(0);
        let world = if j == 0 {
            local
        } else {
            world_per_joint[parent].clone().matmul(local)
        };
        world_per_joint.push(world);
    }
    let world_transforms: Tensor<B, 3> = Tensor::stack(world_per_joint, 0);

    let posed_joints = world_transforms.clone().slice([0..num_joints, 0..3, 3..4]).squeeze(2);

    // Transform the rest joints (padded with a trailing 0) and subtract the
    // result from the translation column.
    let rest_homogen = joints.pad((0, 1, 0, 0), 0.0).unsqueeze_dim(2);
    let rest_in_world: Tensor<B, 2> = world_transforms.clone().matmul(rest_homogen).squeeze(2);
    let correction = Tensor::zeros_like(&world_transforms).slice_assign([0..num_joints, 0..4, 3..4], rest_in_world.unsqueeze_dim(2));
    let rel_transforms = world_transforms - correction;

    (posed_joints, rel_transforms)
}
/// Faster Burn implementation of `batch_rigid_transform` (single-skeleton version).
///
/// - `parent_idx_per_joint_t`: Tensor<Int> shape [J] (index tensor); its first
///   entry is overwritten with 0 so the root points at itself.
/// - `_parent_idx_per_joint`: `ndarray::Array1`<u32> (cpu-side parent indices,
///   same shape); not used by this implementation.
/// - `rot_mats`: Tensor shape [J,3,3]
/// - `joints`: Tensor shape [J,3]
/// - `kinematic_tree_depth`: depth of the kinematic tree; the loop runs
///   `ceil(log2(kinematic_tree_depth))` times (0 times when `J <= 1`).
///
/// Returns `(posed_joints, rel_transforms)`: the translation column of each
/// accumulated transform, and each accumulated transform with the transformed
/// rest joint subtracted from its last column.
///
/// Instead of a sequential loop over all joints, the transforms are
/// accumulated in a logarithmic number of steps. In every step each joint
/// multiplies its current chain by the chain of the ancestor it points to,
/// and then doubles the distance to that ancestor:
///
/// ```text
/// chain[j]      = chain[parent_pow[j]] · chain[j]
/// parent_pow[j] = parent_pow[parent_pow[j]]
/// ```
///
/// The root entry of the chain is replaced by the identity for the duration
/// of the loop, and the root's local transform is applied to every joint once
/// at the end. For the linear tree `0 -> 1 -> 2 -> 3` with local transforms
/// `L[j]`:
///
/// ```text
/// start:        chain = [I, L1,       L2,        L3]
/// after step 1: chain = [I, L1,    L1·L2,     L2·L3]
/// after step 2: chain = [I, L1,    L1·L2,  L1·L2·L3]
/// final:        chain = [L0, L0·L1, L0·L1·L2, L0·L1·L2·L3]
/// ```
pub fn batch_rigid_transform_burn_fast<B: Backend>(
    mut parent_idx_per_joint_t: Tensor<B, 1, Int>,
    _parent_idx_per_joint: &nd::Array1<u32>,
    rot_mats: Tensor<B, 3>,
    joints: Tensor<B, 2>,
    kinematic_tree_depth: usize,
) -> (Tensor<B, 2>, Tensor<B, 3>) {
    let num_joints = joints.dims()[0];
    // Make the root its own parent so lookups for joint 0 stay in range.
    parent_idx_per_joint_t = parent_idx_per_joint_t.clone().slice_fill(0..1, 0);
    // Joint offsets relative to the parent; the root keeps its own position.
    let parent_joints = joints.clone().select(0, parent_idx_per_joint_t.clone());
    let mut rel_joints = joints.clone() - parent_joints;
    rel_joints = rel_joints.slice_assign([0..1, 0..3], joints.clone().slice([0..1, 0..3]));
    // Local transforms [R | t; 0 0 0 1] for all joints at once.
    let t_col = rel_joints.reshape([num_joints, 3, 1]);
    let upper = Tensor::cat(vec![rot_mats, t_col], 2);
    let mut eye_row = Tensor::zeros([num_joints, 1, 4], &joints.device());
    eye_row = eye_row.slice_fill([0..num_joints, 0..1, 3..4], 1.0);
    let transforms = Tensor::cat(vec![upper, eye_row], 1);
    // The root entry is identity while accumulating, so reaching the root
    // repeatedly does not multiply its transform in more than once.
    let mut transform_chain = transforms.clone();
    let identity = Tensor::eye(4, &joints.device()).unsqueeze_dim(0);
    transform_chain = transform_chain.slice_assign([0..1, 0..4, 0..4], identity.clone());
    let mut parent_pow = parent_idx_per_joint_t.clone();
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_sign_loss)]
    #[allow(clippy::cast_precision_loss)]
    let max_steps = if num_joints <= 1 {
        0usize
    } else {
        (kinematic_tree_depth as f32).log2().ceil() as usize
    };
    for _ in 0..max_steps {
        // Both lookups must use the `parent_pow` from before this step.
        let parent_transforms = transform_chain.clone().select(0, parent_pow.clone());
        let new_chain = parent_transforms.matmul(transform_chain.clone());
        parent_pow = parent_pow.clone().select(0, parent_pow.clone());
        transform_chain = new_chain;
    }
    // Apply the root's local transform to every accumulated chain.
    let root_transform = transforms.clone().slice([0..1, 0..4, 0..4]);
    let transform_chain = root_transform.matmul(transform_chain);
    let posed_joints = transform_chain.clone().slice([0..num_joints, 0..3, 3..4]).squeeze(2);
    // Transform the rest joints (padded with a trailing 0) and subtract the
    // result from the translation column.
    let joints_homo = joints.pad((0, 1, 0, 0), 0.0).unsqueeze_dim(2);
    let transformed_joint: Tensor<B, 2> = transform_chain.clone().matmul(joints_homo).squeeze(2);
    let mut transformed_joint_4 = Tensor::zeros_like(&transform_chain.clone());
    transformed_joint_4 = transformed_joint_4.slice_assign([0..num_joints, 0..4, 3..4], transformed_joint.unsqueeze_dim(2));
    let rel_transforms = transform_chain - transformed_joint_4;
    (posed_joints, rel_transforms)
}
/// Builds one 4x4 extrinsics matrix per frame from translations and
/// rotations, returned as an `N x 4 x 4` array.
///
/// - `translations`: `N x 3` array, one translation per row.
/// - `rotations`: `N x 4` array, one quaternion per row in `xyzw` order; each
///   quaternion is normalised before use.
///
/// When the two inputs differ in row count, only the first
/// `min(rows(translations), rows(rotations))` frames are converted.
pub fn extract_extrinsics_from_rot_trans(translations: &ndarray::Array2<f32>, rotations: &ndarray::Array2<f32>) -> ndarray::Array3<f32> {
    let num_frames = translations.shape()[0].min(rotations.shape()[0]);
    let mut extrinsics = ndarray::Array3::<f32>::zeros((num_frames, 4, 4));

    for (frame, mut frame_out) in extrinsics.axis_iter_mut(ndarray::Axis(0)).enumerate() {
        let t = translations.row(frame);
        let q = rotations.row(frame);

        let translation = nalgebra::Translation3::new(t[0], t[1], t[2]);
        // Input rows are xyzw, while `Quaternion::new` takes w first.
        let rotation = nalgebra::UnitQuaternion::new_normalize(nalgebra::Quaternion::new(q[3], q[0], q[1], q[2]));

        let homogeneous = nalgebra::Isometry3::from_parts(translation, rotation).to_homogeneous();
        frame_out.assign(&homogeneous.ref_ndarray2());
    }

    extrinsics
}
/// Returns the depth of the kinematic tree described by `parent_idx_per_joint`.
///
/// For every joint the parent indices are followed upwards, counting one step
/// per visited joint, until a parent index of `0` or an index outside the
/// array is found. The largest count over all joints is returned; an empty
/// array yields `0`.
///
/// Each walk is capped at `parent_idx_per_joint.len()` steps, so a malformed
/// array containing a cycle that never reaches index `0` terminates (with the
/// cap as its depth) instead of looping forever. Valid trees never reach the
/// cap early, so their result is unaffected.
pub fn compute_tree_depth(parent_idx_per_joint: &nd::Array1<u32>) -> usize {
    let nr_joints = parent_idx_per_joint.len();
    let mut max_depth = 0;
    for start in 0..nr_joints {
        let mut depth = 0usize;
        let mut current_idx = start;
        loop {
            depth += 1;
            let parent = parent_idx_per_joint[current_idx] as usize;
            // Stop at the root, at an out-of-range sentinel, or once the walk
            // is longer than any acyclic path could be.
            if parent == 0 || parent >= nr_joints || depth >= nr_joints {
                break;
            }
            current_idx = parent;
        }
        max_depth = max_depth.max(depth);
    }
    max_depth
}