Skip to main content

mlx_native/ops/
log_elementwise.rs

1//! Elementwise natural log forward + backward.
2//!
3//! Used by reverse-mode autograd in downstream crates (hf2q ADR-020
4//! Track 1: log_softmax + KL-div composition).
5
6use metal::MTLSize;
7
8use crate::buffer::MlxBuffer;
9use crate::dtypes::DType;
10use crate::encoder::CommandEncoder;
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
14pub static LOG_SHADER_SOURCE: &str = include_str!("../shaders/log_elementwise.metal");
15
16pub fn register(registry: &mut KernelRegistry) {
17    registry.register_source("log_f32", LOG_SHADER_SOURCE);
18    registry.register_source("log_backward_f32", LOG_SHADER_SOURCE);
19}
20
21/// Encode `output[i] = log(input[i])` for f32 input.
22///
23/// Caller must ensure `input` is strictly positive — the kernel does
24/// not check; `log(x ≤ 0)` produces NaN or `-inf` per IEEE 754.
25pub fn dispatch_log_f32(
26    encoder: &mut CommandEncoder,
27    registry: &mut KernelRegistry,
28    device: &metal::DeviceRef,
29    input: &MlxBuffer,
30    output: &MlxBuffer,
31) -> Result<()> {
32    let n = input.element_count();
33    if n == 0 {
34        return Err(MlxError::InvalidArgument(
35            "log_f32: input must have at least one element".into(),
36        ));
37    }
38    if output.element_count() != n {
39        return Err(MlxError::InvalidArgument(format!(
40            "log_f32: output element count {} != input element count {}",
41            output.element_count(),
42            n
43        )));
44    }
45    if input.dtype() != DType::F32 || output.dtype() != DType::F32 {
46        return Err(MlxError::InvalidArgument(format!(
47            "log_f32: only f32 supported; got input={} output={}",
48            input.dtype(),
49            output.dtype()
50        )));
51    }
52
53    let pipeline = registry.get_pipeline("log_f32", device)?;
54    let thread_count = n as u64;
55    let threadgroup_size = std::cmp::min(256, thread_count);
56
57    encoder.encode(
58        pipeline,
59        &[(0, input), (1, output)],
60        MTLSize::new(thread_count, 1, 1),
61        MTLSize::new(threadgroup_size, 1, 1),
62    );
63
64    Ok(())
65}
66
67/// Encode `dx[i] = dy[i] / x[i]` (the backward pass for elementwise
68/// log).  `x` is the FORWARD INPUT, not the forward output.
69pub fn dispatch_log_backward_f32(
70    encoder: &mut CommandEncoder,
71    registry: &mut KernelRegistry,
72    device: &metal::DeviceRef,
73    x: &MlxBuffer,
74    dy: &MlxBuffer,
75    dx: &MlxBuffer,
76) -> Result<()> {
77    let n = x.element_count();
78    if n == 0 {
79        return Err(MlxError::InvalidArgument(
80            "log_backward_f32: x must have at least one element".into(),
81        ));
82    }
83    for (label, buf) in [("dy", dy), ("dx", dx)] {
84        if buf.element_count() != n {
85            return Err(MlxError::InvalidArgument(format!(
86                "log_backward_f32: {label} element count {} != x element count {}",
87                buf.element_count(),
88                n
89            )));
90        }
91    }
92    for (label, buf) in [("x", x), ("dy", dy), ("dx", dx)] {
93        if buf.dtype() != DType::F32 {
94            return Err(MlxError::InvalidArgument(format!(
95                "log_backward_f32: {label} dtype {} not f32",
96                buf.dtype()
97            )));
98        }
99    }
100
101    let pipeline = registry.get_pipeline("log_backward_f32", device)?;
102    let thread_count = n as u64;
103    let threadgroup_size = std::cmp::min(256, thread_count);
104
105    encoder.encode(
106        pipeline,
107        &[(0, x), (1, dy), (2, dx)],
108        MTLSize::new(thread_count, 1, 1),
109        MTLSize::new(threadgroup_size, 1, 1),
110    );
111
112    Ok(())
113}