pjrt/
host_buffer.rs

1use std::ffi::c_void;
2use std::mem;
3use std::rc::Rc;
4
5use bon::bon;
6use pjrt_sys::{
7    PJRT_Buffer_MemoryLayout, PJRT_Buffer_Type, PJRT_Client_BufferFromHostBuffer_Args,
8    PJRT_HostBufferSemantics,
9    PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kImmutableOnlyDuringCall,
10    PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes,
11    PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kImmutableZeroCopy,
12    PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kMutableZeroCopy,
13};
14
15use crate::event::Event;
16use crate::{
17    utils, Buffer, Client, Device, ElemType, Error, Memory, MemoryLayout, PrimitiveType, Result,
18    Type, F32, F64, I16, I32, I64, I8, U16, U32, U64, U8,
19};
20
21#[derive(Debug)]
22pub struct TypedHostBuffer<T: Type> {
23    data: Rc<Vec<T::ElemType>>,
24    dims: Vec<i64>,
25    layout: MemoryLayout,
26}
27
28impl<T: Type> TypedHostBuffer<T> {
29    pub fn builder() -> TypedHostBufferBuilder {
30        TypedHostBufferBuilder
31    }
32
33    pub fn scalar(data: T::ElemType) -> Self {
34        let data = vec![data];
35        let dims = vec![];
36        let layout = MemoryLayout::strides(vec![]);
37        Self {
38            data: Rc::new(data),
39            dims,
40            layout,
41        }
42    }
43
44    pub fn data(&self) -> &[T::ElemType] {
45        &self.data
46    }
47
48    pub fn dims(&self) -> &[i64] {
49        &self.dims
50    }
51
52    pub fn layout(&self) -> &MemoryLayout {
53        &self.layout
54    }
55
56    pub fn call_copy_to<D>(
57        &self,
58        config: &HostBufferCopyToConfig<D>,
59    ) -> Result<PJRT_Client_BufferFromHostBuffer_Args>
60    where
61        D: HostBufferCopyToDest,
62    {
63        let client = config.dest.client();
64        let mut args = PJRT_Client_BufferFromHostBuffer_Args::new();
65        args.client = client.ptr();
66        args.data = self.data.as_ptr() as *const c_void;
67        args.type_ = T::PRIMITIVE_TYPE as PJRT_Buffer_Type;
68        args.dims = self.dims.as_ptr();
69        args.num_dims = self.dims.len();
70        args.host_buffer_semantics =
71            HostBufferSemantics::ImmutableUntilTransferCompletes as PJRT_HostBufferSemantics;
72        if let Some(byte_strides) = &config.byte_strides {
73            args.byte_strides = byte_strides.as_ptr() as *const _;
74            args.num_byte_strides = byte_strides.len();
75        }
76        if let Some(device_layout) = &config.device_layout {
77            let mut device_layout = PJRT_Buffer_MemoryLayout::from(device_layout);
78            args.device_layout = &mut device_layout as *mut _;
79        }
80        config.dest.set_args(&mut args)?;
81        client.api().PJRT_Client_BufferFromHostBuffer(args)
82    }
83
84    pub fn copy_to_sync<D, C>(&self, config: C) -> Result<Buffer>
85    where
86        D: HostBufferCopyToDest,
87        C: IntoHostBufferCopyToConfig<D>,
88    {
89        let config = config.into_copy_to_config();
90        let client = config.dest.client();
91        let args = self.call_copy_to(&config)?;
92        let done_with_host_event = Event::wrap(client.api(), args.done_with_host_buffer);
93        done_with_host_event.wait()?;
94        let buf = Buffer::wrap(client, args.buffer);
95        let buf_ready_event = buf.ready_event()?;
96        buf_ready_event.wait()?;
97        Ok(buf)
98    }
99
100    pub async fn copy_to<D, C>(&self, config: C) -> Result<Buffer>
101    where
102        D: HostBufferCopyToDest,
103        C: IntoHostBufferCopyToConfig<D>,
104    {
105        let config = config.into_copy_to_config();
106        let client = config.dest.client();
107        let args = self.call_copy_to(&config)?;
108        let done_with_host_event = Event::wrap(client.api(), args.done_with_host_buffer);
109        done_with_host_event.await?;
110        let buf = Buffer::wrap(client, args.buffer);
111        let buf_ready_event = buf.ready_event()?;
112        buf_ready_event.await?;
113        Ok(buf)
114    }
115}
116
117macro_rules! impl_from_typed_buffer {
118    ($T:ident) => {
119        impl From<TypedHostBuffer<$T>> for HostBuffer {
120            fn from(buf: TypedHostBuffer<$T>) -> Self {
121                Self::$T(buf)
122            }
123        }
124    };
125}
126
127impl_from_typed_buffer!(F32);
128impl_from_typed_buffer![F64];
129impl_from_typed_buffer![I8];
130impl_from_typed_buffer![I16];
131impl_from_typed_buffer![I32];
132impl_from_typed_buffer![I64];
133impl_from_typed_buffer![U8];
134impl_from_typed_buffer![U16];
135impl_from_typed_buffer![U32];
136impl_from_typed_buffer![U64];
137
138#[derive(Debug)]
139pub enum HostBuffer {
140    F32(TypedHostBuffer<F32>),
141    F64(TypedHostBuffer<F64>),
142    I8(TypedHostBuffer<I8>),
143    I16(TypedHostBuffer<I16>),
144    I32(TypedHostBuffer<I32>),
145    I64(TypedHostBuffer<I64>),
146    U8(TypedHostBuffer<U8>),
147    U16(TypedHostBuffer<U16>),
148    U32(TypedHostBuffer<U32>),
149    U64(TypedHostBuffer<U64>),
150}
151
152impl HostBuffer {
153    pub fn builder() -> HostBufferBuilder {
154        HostBufferBuilder
155    }
156
157    pub fn scalar<E>(data: E) -> HostBuffer
158    where
159        E: ElemType,
160        Self: From<TypedHostBuffer<E::Type>>,
161    {
162        let buf = TypedHostBuffer::<E::Type>::scalar(data);
163        Self::from(buf)
164    }
165
166    pub fn dims(&self) -> &[i64] {
167        match self {
168            Self::F32(buf) => buf.dims(),
169            Self::F64(buf) => buf.dims(),
170            Self::I8(buf) => buf.dims(),
171            Self::I16(buf) => buf.dims(),
172            Self::I32(buf) => buf.dims(),
173            Self::I64(buf) => buf.dims(),
174            Self::U8(buf) => buf.dims(),
175            Self::U16(buf) => buf.dims(),
176            Self::U32(buf) => buf.dims(),
177            Self::U64(buf) => buf.dims(),
178        }
179    }
180
181    pub fn layout(&self) -> &MemoryLayout {
182        match self {
183            Self::F32(buf) => buf.layout(),
184            Self::F64(buf) => buf.layout(),
185            Self::I8(buf) => buf.layout(),
186            Self::I16(buf) => buf.layout(),
187            Self::I32(buf) => buf.layout(),
188            Self::I64(buf) => buf.layout(),
189            Self::U8(buf) => buf.layout(),
190            Self::U16(buf) => buf.layout(),
191            Self::U32(buf) => buf.layout(),
192            Self::U64(buf) => buf.layout(),
193        }
194    }
195
196    pub fn copy_to_sync<D, C>(&self, config: C) -> Result<Buffer>
197    where
198        D: HostBufferCopyToDest,
199        C: IntoHostBufferCopyToConfig<D>,
200    {
201        match self {
202            Self::F32(buf) => buf.copy_to_sync(config),
203            Self::F64(buf) => buf.copy_to_sync(config),
204            Self::I8(buf) => buf.copy_to_sync(config),
205            Self::I16(buf) => buf.copy_to_sync(config),
206            Self::I32(buf) => buf.copy_to_sync(config),
207            Self::I64(buf) => buf.copy_to_sync(config),
208            Self::U8(buf) => buf.copy_to_sync(config),
209            Self::U16(buf) => buf.copy_to_sync(config),
210            Self::U32(buf) => buf.copy_to_sync(config),
211            Self::U64(buf) => buf.copy_to_sync(config),
212        }
213    }
214
215    pub async fn copy_to<D, C>(&self, config: C) -> Result<Buffer>
216    where
217        D: HostBufferCopyToDest,
218        C: IntoHostBufferCopyToConfig<D>,
219    {
220        match self {
221            Self::F32(buf) => buf.copy_to(config).await,
222            Self::F64(buf) => buf.copy_to(config).await,
223            Self::I8(buf) => buf.copy_to(config).await,
224            Self::I16(buf) => buf.copy_to(config).await,
225            Self::I32(buf) => buf.copy_to(config).await,
226            Self::I64(buf) => buf.copy_to(config).await,
227            Self::U8(buf) => buf.copy_to(config).await,
228            Self::U16(buf) => buf.copy_to(config).await,
229            Self::U32(buf) => buf.copy_to(config).await,
230            Self::U64(buf) => buf.copy_to(config).await,
231        }
232    }
233}
234
235#[repr(i32)]
236#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
237#[allow(dead_code)]
238pub enum HostBufferSemantics {
239    /// The runtime may not hold references to `data` after the call to
240    /// `PJRT_Client_BufferFromHostBuffer` completes. The caller promises that
241    /// `data` is immutable and will not be freed only for the duration of the
242    /// PJRT_Client_BufferFromHostBuffer call.
243    ImmutableOnlyDuringCall =
244        PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kImmutableOnlyDuringCall as i32,
245
246    /// The runtime may hold onto `data` after the call to
247    /// `PJRT_Client_BufferFromHostBuffer`
248    /// returns while the runtime completes a transfer to the device. The caller
249    /// promises not to mutate or free `data` until the transfer completes, at
250    /// which point `done_with_host_buffer` will be triggered.
251    ImmutableUntilTransferCompletes =
252        PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes as i32,
253
254    /// The PjRtBuffer may alias `data` internally and the runtime may use the
255    /// `data` contents as long as the buffer is alive. The runtime promises not
256    /// to mutate contents of the buffer (i.e. it will not use it for aliased
257    /// output buffers). The caller promises to keep `data` alive and not to mutate
258    /// its contents as long as the buffer is alive; to notify the caller that the
259    /// buffer may be freed, the runtime will call `done_with_host_buffer` when the
260    /// PjRtBuffer is freed.
261    ImmutableZeroCopy = PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kImmutableZeroCopy as i32,
262
263    /// The PjRtBuffer may alias `data` internally and the runtime may use the
264    /// `data` contents as long as the buffer is alive. The runtime is allowed
265    /// to mutate contents of the buffer (i.e. use it for aliased output
266    /// buffers). The caller promises to keep `data` alive and not to mutate its
267    /// contents as long as the buffer is alive (otherwise it could be a data
268    /// race with the runtime); to notify the caller that the buffer may be
269    /// freed, the runtime will call `on_done_with_host_buffer` when the
270    /// PjRtBuffer is freed. On non-CPU platforms this acts identically to
271    /// kImmutableUntilTransferCompletes.
272    MutableZeroCopy = PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kMutableZeroCopy as i32,
273}
274
275pub trait HostBufferCopyToDest {
276    fn client(&self) -> &Client;
277    fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()>;
278}
279
280impl HostBufferCopyToDest for Client {
281    fn client(&self) -> &Client {
282        self
283    }
284
285    fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()> {
286        args.device = self
287            .addressable_devices()
288            .first()
289            .ok_or(Error::NoAddressableDevice)?
290            .ptr;
291        Ok(())
292    }
293}
294
295impl<'a> HostBufferCopyToDest for &'a Client {
296    fn client(&self) -> &Client {
297        self
298    }
299
300    fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()> {
301        args.device = self
302            .addressable_devices()
303            .first()
304            .ok_or(Error::NoAddressableDevice)?
305            .ptr;
306        Ok(())
307    }
308}
309
310impl HostBufferCopyToDest for Device {
311    fn client(&self) -> &Client {
312        Device::client(self)
313    }
314
315    fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()> {
316        args.device = self.ptr;
317        Ok(())
318    }
319}
320
321impl<'a> HostBufferCopyToDest for &'a Device {
322    fn client(&self) -> &Client {
323        Device::client(self)
324    }
325
326    fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()> {
327        args.device = self.ptr;
328        Ok(())
329    }
330}
331
332impl HostBufferCopyToDest for Memory {
333    fn client(&self) -> &Client {
334        Memory::client(self)
335    }
336
337    fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()> {
338        args.memory = self.ptr;
339        Ok(())
340    }
341}
342
343impl<'a> HostBufferCopyToDest for &'a Memory {
344    fn client(&self) -> &Client {
345        Memory::client(self)
346    }
347
348    fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()> {
349        args.memory = self.ptr;
350        Ok(())
351    }
352}
353
354pub struct HostBufferCopyToConfig<D>
355where
356    D: HostBufferCopyToDest,
357{
358    dest: D,
359    byte_strides: Option<Vec<i64>>,
360    device_layout: Option<MemoryLayout>,
361}
362
363impl<D> HostBufferCopyToConfig<D>
364where
365    D: HostBufferCopyToDest,
366{
367    pub fn new(dest: D) -> Self {
368        Self {
369            dest,
370            byte_strides: None,
371            device_layout: None,
372        }
373    }
374
375    pub fn byte_strides(mut self, byte_strides: Vec<i64>) -> Self {
376        self.byte_strides = Some(byte_strides);
377        self
378    }
379
380    pub fn device_layout(mut self, device_layout: MemoryLayout) -> Self {
381        self.device_layout = Some(device_layout);
382        self
383    }
384}
385
386mod private {
387    use crate::host_buffer::{HostBufferCopyToConfig, HostBufferCopyToDest};
388    use crate::MemoryLayout;
389
390    pub trait Argument {
391        type Repr;
392    }
393
394    pub trait ToConfig<A, D>
395    where
396        D: HostBufferCopyToDest,
397    {
398        fn into_config(self) -> HostBufferCopyToConfig<D>;
399    }
400
401    impl<D> Argument for D
402    where
403        D: HostBufferCopyToDest,
404    {
405        type Repr = (D,);
406    }
407
408    impl<D> ToConfig<(D,), D> for D
409    where
410        D: HostBufferCopyToDest,
411    {
412        fn into_config(self) -> HostBufferCopyToConfig<D> {
413            HostBufferCopyToConfig::new(self)
414        }
415    }
416
417    impl<D, B> Argument for (D, B)
418    where
419        D: HostBufferCopyToDest,
420        B: Into<Vec<i64>>,
421    {
422        type Repr = (D, B);
423    }
424
425    impl<D, B> ToConfig<(D, B), D> for (D, B)
426    where
427        D: HostBufferCopyToDest,
428        B: Into<Vec<i64>>,
429    {
430        fn into_config(self) -> HostBufferCopyToConfig<D> {
431            HostBufferCopyToConfig::new(self.0).byte_strides(self.1.into())
432        }
433    }
434
435    impl<D> Argument for (D, MemoryLayout)
436    where
437        D: HostBufferCopyToDest,
438    {
439        type Repr = (D, MemoryLayout);
440    }
441
442    impl<D> ToConfig<(D, MemoryLayout), D> for (D, MemoryLayout)
443    where
444        D: HostBufferCopyToDest,
445    {
446        fn into_config(self) -> HostBufferCopyToConfig<D> {
447            HostBufferCopyToConfig::new(self.0).device_layout(self.1)
448        }
449    }
450
451    impl<'a, D> Argument for (D, &'a MemoryLayout)
452    where
453        D: HostBufferCopyToDest,
454    {
455        type Repr = (D, &'a MemoryLayout);
456    }
457
458    impl<'a, D> ToConfig<(D, &'a MemoryLayout), D> for (D, &'a MemoryLayout)
459    where
460        D: HostBufferCopyToDest,
461    {
462        fn into_config(self) -> HostBufferCopyToConfig<D> {
463            HostBufferCopyToConfig::new(self.0).device_layout(self.1.clone())
464        }
465    }
466
467    impl<D, B, M> Argument for (D, B, M)
468    where
469        D: HostBufferCopyToDest,
470        B: Into<Vec<i64>>,
471        M: Into<MemoryLayout>,
472    {
473        type Repr = (D, B, M);
474    }
475
476    impl<D, B, M> ToConfig<(D, B, M), D> for (D, B, M)
477    where
478        D: HostBufferCopyToDest,
479        B: Into<Vec<i64>>,
480        M: Into<MemoryLayout>,
481    {
482        fn into_config(self) -> HostBufferCopyToConfig<D> {
483            HostBufferCopyToConfig::new(self.0)
484                .byte_strides(self.1.into())
485                .device_layout(self.2.into())
486        }
487    }
488}
489
490pub trait IntoHostBufferCopyToConfig<D>
491where
492    D: HostBufferCopyToDest,
493{
494    fn into_copy_to_config(self) -> HostBufferCopyToConfig<D>;
495}
496
497impl<T, D> IntoHostBufferCopyToConfig<D> for T
498where
499    T: private::Argument + private::ToConfig<T::Repr, D>,
500    D: HostBufferCopyToDest,
501{
502    fn into_copy_to_config(self) -> HostBufferCopyToConfig<D> {
503        self.into_config()
504    }
505}
506
507#[derive(Debug)]
508pub struct TypedHostBufferBuilder;
509
510#[bon]
511impl TypedHostBufferBuilder {
512    #[builder(finish_fn = build)]
513    pub fn data<E>(
514        &self,
515        #[builder(start_fn, into)] data: Vec<E>,
516        #[builder(into)] dims: Option<Vec<i64>>,
517        #[builder] layout: Option<MemoryLayout>,
518    ) -> TypedHostBuffer<E::Type>
519    where
520        E: ElemType,
521    {
522        let dims = dims.unwrap_or_else(|| vec![data.len() as i64]);
523        let layout = layout
524            .unwrap_or_else(|| MemoryLayout::strides(utils::byte_strides(&dims, E::Type::SIZE)));
525        TypedHostBuffer {
526            data: Rc::new(data),
527            dims,
528            layout,
529        }
530    }
531
532    #[builder(finish_fn = build)]
533    pub fn bytes<T>(
534        &self,
535        #[builder(start_fn, into)] bytes: Vec<u8>,
536        #[builder(into)] dims: Option<Vec<i64>>,
537        #[builder] layout: Option<MemoryLayout>,
538    ) -> TypedHostBuffer<T>
539    where
540        T: Type,
541    {
542        let length = bytes.len() / T::SIZE;
543        let capacity = bytes.capacity() / T::SIZE;
544        let ptr = bytes.as_ptr() as *mut T::ElemType;
545        let data = unsafe { Vec::from_raw_parts(ptr, length, capacity) };
546        mem::forget(bytes);
547        let dims = dims.unwrap_or_else(|| vec![length as i64]);
548        assert!(dims.iter().product::<i64>() == length as i64);
549        let layout =
550            layout.unwrap_or_else(|| MemoryLayout::strides(utils::byte_strides(&dims, T::SIZE)));
551        TypedHostBuffer {
552            data: Rc::new(data),
553            dims,
554            layout,
555        }
556    }
557}
558
559#[derive(Debug)]
560pub struct HostBufferBuilder;
561
562#[bon]
563impl HostBufferBuilder {
564    #[builder(finish_fn = build)]
565    pub fn data<E>(
566        &self,
567        #[builder(start_fn, into)] data: Vec<E>,
568        #[builder(into)] dims: Option<Vec<i64>>,
569        #[builder] layout: Option<MemoryLayout>,
570    ) -> HostBuffer
571    where
572        E: ElemType,
573        HostBuffer: From<TypedHostBuffer<E::Type>>,
574    {
575        let buf = TypedHostBufferBuilder
576            .data::<E>(data)
577            .maybe_dims(dims)
578            .maybe_layout(layout)
579            .build();
580        HostBuffer::from(buf)
581    }
582
583    #[builder(finish_fn = build)]
584    pub fn bytes(
585        &self,
586        #[builder(start_fn)] bytes: Vec<u8>,
587        #[builder(start_fn)] ty: PrimitiveType,
588        #[builder(into)] dims: Option<Vec<i64>>,
589        #[builder] layout: Option<MemoryLayout>,
590    ) -> Result<HostBuffer> {
591        match ty {
592            PrimitiveType::F32 => Ok(HostBuffer::F32(
593                TypedHostBufferBuilder
594                    .bytes::<F32>(bytes)
595                    .maybe_dims(dims)
596                    .maybe_layout(layout)
597                    .build(),
598            )),
599            PrimitiveType::F64 => Ok(HostBuffer::F64(
600                TypedHostBufferBuilder
601                    .bytes::<F64>(bytes)
602                    .maybe_dims(dims)
603                    .maybe_layout(layout)
604                    .build(),
605            )),
606            PrimitiveType::S8 => Ok(HostBuffer::I8(
607                TypedHostBufferBuilder
608                    .bytes::<I8>(bytes)
609                    .maybe_dims(dims)
610                    .maybe_layout(layout)
611                    .build(),
612            )),
613            PrimitiveType::S16 => Ok(HostBuffer::I16(
614                TypedHostBufferBuilder
615                    .bytes::<I16>(bytes)
616                    .maybe_dims(dims)
617                    .maybe_layout(layout)
618                    .build(),
619            )),
620            PrimitiveType::S32 => Ok(HostBuffer::I32(
621                TypedHostBufferBuilder
622                    .bytes::<I32>(bytes)
623                    .maybe_dims(dims)
624                    .maybe_layout(layout)
625                    .build(),
626            )),
627            PrimitiveType::S64 => Ok(HostBuffer::I64(
628                TypedHostBufferBuilder
629                    .bytes::<I64>(bytes)
630                    .maybe_dims(dims)
631                    .maybe_layout(layout)
632                    .build(),
633            )),
634            PrimitiveType::U8 => Ok(HostBuffer::U8(
635                TypedHostBufferBuilder
636                    .bytes::<U8>(bytes)
637                    .maybe_dims(dims)
638                    .maybe_layout(layout)
639                    .build(),
640            )),
641            PrimitiveType::U16 => Ok(HostBuffer::U16(
642                TypedHostBufferBuilder
643                    .bytes::<U16>(bytes)
644                    .maybe_dims(dims)
645                    .maybe_layout(layout)
646                    .build(),
647            )),
648            PrimitiveType::U32 => Ok(HostBuffer::U32(
649                TypedHostBufferBuilder
650                    .bytes::<U32>(bytes)
651                    .maybe_dims(dims)
652                    .maybe_layout(layout)
653                    .build(),
654            )),
655            PrimitiveType::U64 => Ok(HostBuffer::U64(
656                TypedHostBufferBuilder
657                    .bytes::<U64>(bytes)
658                    .maybe_dims(dims)
659                    .maybe_layout(layout)
660                    .build(),
661            )),
662            _ => Err(Error::NotSupportedType(ty)),
663        }
664    }
665}