mlx_native/ops/
log_elementwise.rs1use metal::MTLSize;
7
8use crate::buffer::MlxBuffer;
9use crate::dtypes::DType;
10use crate::encoder::CommandEncoder;
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
/// Metal shader source shared by the `log_f32` and `log_backward_f32` kernels.
///
/// The text is embedded into the binary at compile time from
/// `../shaders/log_elementwise.metal`; [`register`] hands this same source to
/// the kernel registry under both kernel names.
pub static LOG_SHADER_SOURCE: &str = include_str!("../shaders/log_elementwise.metal");
15
16pub fn register(registry: &mut KernelRegistry) {
17 registry.register_source("log_f32", LOG_SHADER_SOURCE);
18 registry.register_source("log_backward_f32", LOG_SHADER_SOURCE);
19}
20
21pub fn dispatch_log_f32(
26 encoder: &mut CommandEncoder,
27 registry: &mut KernelRegistry,
28 device: &metal::DeviceRef,
29 input: &MlxBuffer,
30 output: &MlxBuffer,
31) -> Result<()> {
32 let n = input.element_count();
33 if n == 0 {
34 return Err(MlxError::InvalidArgument(
35 "log_f32: input must have at least one element".into(),
36 ));
37 }
38 if output.element_count() != n {
39 return Err(MlxError::InvalidArgument(format!(
40 "log_f32: output element count {} != input element count {}",
41 output.element_count(),
42 n
43 )));
44 }
45 if input.dtype() != DType::F32 || output.dtype() != DType::F32 {
46 return Err(MlxError::InvalidArgument(format!(
47 "log_f32: only f32 supported; got input={} output={}",
48 input.dtype(),
49 output.dtype()
50 )));
51 }
52
53 let pipeline = registry.get_pipeline("log_f32", device)?;
54 let thread_count = n as u64;
55 let threadgroup_size = std::cmp::min(256, thread_count);
56
57 encoder.encode(
58 pipeline,
59 &[(0, input), (1, output)],
60 MTLSize::new(thread_count, 1, 1),
61 MTLSize::new(threadgroup_size, 1, 1),
62 );
63
64 Ok(())
65}
66
67pub fn dispatch_log_backward_f32(
70 encoder: &mut CommandEncoder,
71 registry: &mut KernelRegistry,
72 device: &metal::DeviceRef,
73 x: &MlxBuffer,
74 dy: &MlxBuffer,
75 dx: &MlxBuffer,
76) -> Result<()> {
77 let n = x.element_count();
78 if n == 0 {
79 return Err(MlxError::InvalidArgument(
80 "log_backward_f32: x must have at least one element".into(),
81 ));
82 }
83 for (label, buf) in [("dy", dy), ("dx", dx)] {
84 if buf.element_count() != n {
85 return Err(MlxError::InvalidArgument(format!(
86 "log_backward_f32: {label} element count {} != x element count {}",
87 buf.element_count(),
88 n
89 )));
90 }
91 }
92 for (label, buf) in [("x", x), ("dy", dy), ("dx", dx)] {
93 if buf.dtype() != DType::F32 {
94 return Err(MlxError::InvalidArgument(format!(
95 "log_backward_f32: {label} dtype {} not f32",
96 buf.dtype()
97 )));
98 }
99 }
100
101 let pipeline = registry.get_pipeline("log_backward_f32", device)?;
102 let thread_count = n as u64;
103 let threadgroup_size = std::cmp::min(256, thread_count);
104
105 encoder.encode(
106 pipeline,
107 &[(0, x), (1, dy), (2, dx)],
108 MTLSize::new(thread_count, 1, 1),
109 MTLSize::new(threadgroup_size, 1, 1),
110 );
111
112 Ok(())
113}