Skip to main content

07_optimizer_and_state/
07_optimizer_and_state.rs

1use apple_metal::{resource_options, MetalBuffer, MetalDevice};
2use apple_mps::{
3    data_type, nn_regularization_type, state_batch_increment_read_count, state_resource_type,
4    NNOptimizerStochasticGradientDescent, State, StateResourceList, Vector, VectorDescriptor,
5};
6
/// Reinterprets a slice of `T` as its raw in-memory bytes (native endianness).
///
/// The returned slice borrows `values` and is exactly `size_of_val(values)`
/// bytes long; an empty input yields an empty byte slice.
///
/// NOTE(review): this is only sound for `T` without padding bytes (e.g. `f32`
/// or plain integers), because reading padding through `&[u8]` is undefined
/// behavior. In this file it is only called with `f32`.
fn as_bytes<T>(values: &[T]) -> &[u8] {
    let byte_len = core::mem::size_of_val(values);
    let first_byte = values.as_ptr() as *const u8;
    // SAFETY: `first_byte` points at the start of `values`, a live, initialized
    // region of exactly `byte_len` bytes; `u8` has alignment 1; and the returned
    // slice's lifetime is tied to the borrow of `values`. A padding-free `T` is
    // assumed (see the note above).
    unsafe { core::slice::from_raw_parts(first_byte, byte_len) }
}
12
13fn buffer_with_f32_values(device: &MetalDevice, values: &[f32]) -> MetalBuffer {
14    let buffer = device
15        .new_buffer(
16            core::mem::size_of_val(values),
17            resource_options::STORAGE_MODE_SHARED,
18        )
19        .expect("buffer");
20    let _ = buffer.write_bytes(as_bytes(values));
21    buffer
22}
23
24fn read_f32_values(buffer: &MetalBuffer, len: usize) -> Vec<f32> {
25    let ptr = buffer.contents().expect("buffer contents").cast::<f32>();
26    unsafe { core::slice::from_raw_parts(ptr, len).to_vec() }
27}
28
29fn vector_with_values(device: &MetalDevice, values: &[f32]) -> (MetalBuffer, Vector) {
30    let buffer = buffer_with_f32_values(device, values);
31    let descriptor =
32        VectorDescriptor::contiguous(values.len(), data_type::FLOAT32).expect("vector desc");
33    let vector = Vector::new_with_buffer(&buffer, descriptor).expect("vector");
34    (buffer, vector)
35}
36
37fn main() {
38    let device = MetalDevice::system_default().expect("no Metal device available");
39    let queue = device.new_command_queue().expect("command queue");
40
41    let command_buffer = queue.new_command_buffer().expect("command buffer");
42    let temporary_a =
43        State::temporary_with_buffer_size(&command_buffer, 32).expect("temporary state a");
44    let temporary_b =
45        State::temporary_with_buffer_size(&command_buffer, 64).expect("temporary state b");
46    let unique_count =
47        state_batch_increment_read_count(&[&temporary_a, &temporary_a, &temporary_b], 1);
48    assert_eq!(unique_count, 2);
49    assert_eq!(
50        temporary_a.resource_type_at_index(0),
51        state_resource_type::BUFFER
52    );
53    command_buffer.commit();
54    command_buffer.wait_until_completed();
55
56    let resource_list = StateResourceList::new().expect("resource list");
57    resource_list.append_buffer(16);
58    let persistent_state =
59        State::new_with_resource_list(&device, &resource_list).expect("persistent state");
60    assert_eq!(persistent_state.resource_count(), 1);
61    assert_eq!(persistent_state.buffer_size_at_index(0), 16);
62
63    let (gradient_buffer, gradient_vector) = vector_with_values(&device, &[0.1, -0.2]);
64    let (values_buffer, values_vector) = vector_with_values(&device, &[1.0, -1.0]);
65    let (result_buffer, result_vector) = vector_with_values(&device, &[0.0, 0.0]);
66    let optimizer = NNOptimizerStochasticGradientDescent::new(&device, 0.5).expect("sgd");
67    let base = optimizer.as_optimizer().expect("optimizer base");
68    assert_eq!(base.regularization_type(), nn_regularization_type::NONE);
69
70    let command_buffer = queue.new_command_buffer().expect("command buffer");
71    optimizer.encode_vector(
72        &command_buffer,
73        &gradient_vector,
74        &values_vector,
75        None,
76        &result_vector,
77    );
78    command_buffer.commit();
79    command_buffer.wait_until_completed();
80
81    let _ = gradient_buffer;
82    let _ = values_buffer;
83    let output = read_f32_values(&result_buffer, 2);
84    println!("{output:?}");
85}