1use std::ffi::c_void;
2use std::mem;
3use std::rc::Rc;
4
5use bon::bon;
6use pjrt_sys::{
7 PJRT_Buffer_MemoryLayout, PJRT_Buffer_Type, PJRT_Client_BufferFromHostBuffer_Args,
8 PJRT_HostBufferSemantics,
9 PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kImmutableOnlyDuringCall,
10 PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes,
11 PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kImmutableZeroCopy,
12 PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kMutableZeroCopy,
13};
14
15use crate::event::Event;
16use crate::{
17 utils, Buffer, Client, Device, ElemType, Error, Memory, MemoryLayout, PrimitiveType, Result,
18 Type, F32, F64, I16, I32, I64, I8, U16, U32, U64, U8,
19};
20
21#[derive(Debug)]
22pub struct TypedHostBuffer<T: Type> {
23 data: Rc<Vec<T::ElemType>>,
24 dims: Vec<i64>,
25 layout: MemoryLayout,
26}
27
28impl<T: Type> TypedHostBuffer<T> {
29 pub fn builder() -> TypedHostBufferBuilder {
30 TypedHostBufferBuilder
31 }
32
33 pub fn scalar(data: T::ElemType) -> Self {
34 let data = vec![data];
35 let dims = vec![];
36 let layout = MemoryLayout::strides(vec![]);
37 Self {
38 data: Rc::new(data),
39 dims,
40 layout,
41 }
42 }
43
44 pub fn data(&self) -> &[T::ElemType] {
45 &self.data
46 }
47
48 pub fn dims(&self) -> &[i64] {
49 &self.dims
50 }
51
52 pub fn layout(&self) -> &MemoryLayout {
53 &self.layout
54 }
55
56 pub fn call_copy_to<D>(
57 &self,
58 config: &HostBufferCopyToConfig<D>,
59 ) -> Result<PJRT_Client_BufferFromHostBuffer_Args>
60 where
61 D: HostBufferCopyToDest,
62 {
63 let client = config.dest.client();
64 let mut args = PJRT_Client_BufferFromHostBuffer_Args::new();
65 args.client = client.ptr();
66 args.data = self.data.as_ptr() as *const c_void;
67 args.type_ = T::PRIMITIVE_TYPE as PJRT_Buffer_Type;
68 args.dims = self.dims.as_ptr();
69 args.num_dims = self.dims.len();
70 args.host_buffer_semantics =
71 HostBufferSemantics::ImmutableUntilTransferCompletes as PJRT_HostBufferSemantics;
72 if let Some(byte_strides) = &config.byte_strides {
73 args.byte_strides = byte_strides.as_ptr() as *const _;
74 args.num_byte_strides = byte_strides.len();
75 }
76 if let Some(device_layout) = &config.device_layout {
77 let mut device_layout = PJRT_Buffer_MemoryLayout::from(device_layout);
78 args.device_layout = &mut device_layout as *mut _;
79 }
80 config.dest.set_args(&mut args)?;
81 client.api().PJRT_Client_BufferFromHostBuffer(args)
82 }
83
84 pub fn copy_to_sync<D, C>(&self, config: C) -> Result<Buffer>
85 where
86 D: HostBufferCopyToDest,
87 C: IntoHostBufferCopyToConfig<D>,
88 {
89 let config = config.into_copy_to_config();
90 let client = config.dest.client();
91 let args = self.call_copy_to(&config)?;
92 let done_with_host_event = Event::wrap(client.api(), args.done_with_host_buffer);
93 done_with_host_event.wait()?;
94 let buf = Buffer::wrap(client, args.buffer);
95 let buf_ready_event = buf.ready_event()?;
96 buf_ready_event.wait()?;
97 Ok(buf)
98 }
99
100 pub async fn copy_to<D, C>(&self, config: C) -> Result<Buffer>
101 where
102 D: HostBufferCopyToDest,
103 C: IntoHostBufferCopyToConfig<D>,
104 {
105 let config = config.into_copy_to_config();
106 let client = config.dest.client();
107 let args = self.call_copy_to(&config)?;
108 let done_with_host_event = Event::wrap(client.api(), args.done_with_host_buffer);
109 done_with_host_event.await?;
110 let buf = Buffer::wrap(client, args.buffer);
111 let buf_ready_event = buf.ready_event()?;
112 buf_ready_event.await?;
113 Ok(buf)
114 }
115}
116
/// Implements `From<TypedHostBuffer<X>> for HostBuffer` for the given
/// identifier `X`, wrapping the typed buffer in the enum variant of the same
/// name.
macro_rules! impl_from_typed_buffer {
    ($variant:ident) => {
        impl From<TypedHostBuffer<$variant>> for HostBuffer {
            fn from(typed: TypedHostBuffer<$variant>) -> Self {
                HostBuffer::$variant(typed)
            }
        }
    };
}
126
127impl_from_typed_buffer!(F32);
128impl_from_typed_buffer![F64];
129impl_from_typed_buffer![I8];
130impl_from_typed_buffer![I16];
131impl_from_typed_buffer![I32];
132impl_from_typed_buffer![I64];
133impl_from_typed_buffer![U8];
134impl_from_typed_buffer![U16];
135impl_from_typed_buffer![U32];
136impl_from_typed_buffer![U64];
137
138#[derive(Debug)]
139pub enum HostBuffer {
140 F32(TypedHostBuffer<F32>),
141 F64(TypedHostBuffer<F64>),
142 I8(TypedHostBuffer<I8>),
143 I16(TypedHostBuffer<I16>),
144 I32(TypedHostBuffer<I32>),
145 I64(TypedHostBuffer<I64>),
146 U8(TypedHostBuffer<U8>),
147 U16(TypedHostBuffer<U16>),
148 U32(TypedHostBuffer<U32>),
149 U64(TypedHostBuffer<U64>),
150}
151
152impl HostBuffer {
153 pub fn builder() -> HostBufferBuilder {
154 HostBufferBuilder
155 }
156
157 pub fn scalar<E>(data: E) -> HostBuffer
158 where
159 E: ElemType,
160 Self: From<TypedHostBuffer<E::Type>>,
161 {
162 let buf = TypedHostBuffer::<E::Type>::scalar(data);
163 Self::from(buf)
164 }
165
166 pub fn dims(&self) -> &[i64] {
167 match self {
168 Self::F32(buf) => buf.dims(),
169 Self::F64(buf) => buf.dims(),
170 Self::I8(buf) => buf.dims(),
171 Self::I16(buf) => buf.dims(),
172 Self::I32(buf) => buf.dims(),
173 Self::I64(buf) => buf.dims(),
174 Self::U8(buf) => buf.dims(),
175 Self::U16(buf) => buf.dims(),
176 Self::U32(buf) => buf.dims(),
177 Self::U64(buf) => buf.dims(),
178 }
179 }
180
181 pub fn layout(&self) -> &MemoryLayout {
182 match self {
183 Self::F32(buf) => buf.layout(),
184 Self::F64(buf) => buf.layout(),
185 Self::I8(buf) => buf.layout(),
186 Self::I16(buf) => buf.layout(),
187 Self::I32(buf) => buf.layout(),
188 Self::I64(buf) => buf.layout(),
189 Self::U8(buf) => buf.layout(),
190 Self::U16(buf) => buf.layout(),
191 Self::U32(buf) => buf.layout(),
192 Self::U64(buf) => buf.layout(),
193 }
194 }
195
196 pub fn copy_to_sync<D, C>(&self, config: C) -> Result<Buffer>
197 where
198 D: HostBufferCopyToDest,
199 C: IntoHostBufferCopyToConfig<D>,
200 {
201 match self {
202 Self::F32(buf) => buf.copy_to_sync(config),
203 Self::F64(buf) => buf.copy_to_sync(config),
204 Self::I8(buf) => buf.copy_to_sync(config),
205 Self::I16(buf) => buf.copy_to_sync(config),
206 Self::I32(buf) => buf.copy_to_sync(config),
207 Self::I64(buf) => buf.copy_to_sync(config),
208 Self::U8(buf) => buf.copy_to_sync(config),
209 Self::U16(buf) => buf.copy_to_sync(config),
210 Self::U32(buf) => buf.copy_to_sync(config),
211 Self::U64(buf) => buf.copy_to_sync(config),
212 }
213 }
214
215 pub async fn copy_to<D, C>(&self, config: C) -> Result<Buffer>
216 where
217 D: HostBufferCopyToDest,
218 C: IntoHostBufferCopyToConfig<D>,
219 {
220 match self {
221 Self::F32(buf) => buf.copy_to(config).await,
222 Self::F64(buf) => buf.copy_to(config).await,
223 Self::I8(buf) => buf.copy_to(config).await,
224 Self::I16(buf) => buf.copy_to(config).await,
225 Self::I32(buf) => buf.copy_to(config).await,
226 Self::I64(buf) => buf.copy_to(config).await,
227 Self::U8(buf) => buf.copy_to(config).await,
228 Self::U16(buf) => buf.copy_to(config).await,
229 Self::U32(buf) => buf.copy_to(config).await,
230 Self::U64(buf) => buf.copy_to(config).await,
231 }
232 }
233}
234
235#[repr(i32)]
236#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
237#[allow(dead_code)]
238pub enum HostBufferSemantics {
239 ImmutableOnlyDuringCall =
244 PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kImmutableOnlyDuringCall as i32,
245
246 ImmutableUntilTransferCompletes =
252 PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes as i32,
253
254 ImmutableZeroCopy = PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kImmutableZeroCopy as i32,
262
263 MutableZeroCopy = PJRT_HostBufferSemantics_PJRT_HostBufferSemantics_kMutableZeroCopy as i32,
273}
274
275pub trait HostBufferCopyToDest {
276 fn client(&self) -> &Client;
277 fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()>;
278}
279
280impl HostBufferCopyToDest for Client {
281 fn client(&self) -> &Client {
282 self
283 }
284
285 fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()> {
286 args.device = self
287 .addressable_devices()
288 .first()
289 .ok_or(Error::NoAddressableDevice)?
290 .ptr;
291 Ok(())
292 }
293}
294
295impl<'a> HostBufferCopyToDest for &'a Client {
296 fn client(&self) -> &Client {
297 self
298 }
299
300 fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()> {
301 args.device = self
302 .addressable_devices()
303 .first()
304 .ok_or(Error::NoAddressableDevice)?
305 .ptr;
306 Ok(())
307 }
308}
309
310impl HostBufferCopyToDest for Device {
311 fn client(&self) -> &Client {
312 Device::client(self)
313 }
314
315 fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()> {
316 args.device = self.ptr;
317 Ok(())
318 }
319}
320
321impl<'a> HostBufferCopyToDest for &'a Device {
322 fn client(&self) -> &Client {
323 Device::client(self)
324 }
325
326 fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()> {
327 args.device = self.ptr;
328 Ok(())
329 }
330}
331
332impl HostBufferCopyToDest for Memory {
333 fn client(&self) -> &Client {
334 Memory::client(self)
335 }
336
337 fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()> {
338 args.memory = self.ptr;
339 Ok(())
340 }
341}
342
343impl<'a> HostBufferCopyToDest for &'a Memory {
344 fn client(&self) -> &Client {
345 Memory::client(self)
346 }
347
348 fn set_args(&self, args: &mut PJRT_Client_BufferFromHostBuffer_Args) -> Result<()> {
349 args.memory = self.ptr;
350 Ok(())
351 }
352}
353
354pub struct HostBufferCopyToConfig<D>
355where
356 D: HostBufferCopyToDest,
357{
358 dest: D,
359 byte_strides: Option<Vec<i64>>,
360 device_layout: Option<MemoryLayout>,
361}
362
363impl<D> HostBufferCopyToConfig<D>
364where
365 D: HostBufferCopyToDest,
366{
367 pub fn new(dest: D) -> Self {
368 Self {
369 dest,
370 byte_strides: None,
371 device_layout: None,
372 }
373 }
374
375 pub fn byte_strides(mut self, byte_strides: Vec<i64>) -> Self {
376 self.byte_strides = Some(byte_strides);
377 self
378 }
379
380 pub fn device_layout(mut self, device_layout: MemoryLayout) -> Self {
381 self.device_layout = Some(device_layout);
382 self
383 }
384}
385
386mod private {
387 use crate::host_buffer::{HostBufferCopyToConfig, HostBufferCopyToDest};
388 use crate::MemoryLayout;
389
390 pub trait Argument {
391 type Repr;
392 }
393
394 pub trait ToConfig<A, D>
395 where
396 D: HostBufferCopyToDest,
397 {
398 fn into_config(self) -> HostBufferCopyToConfig<D>;
399 }
400
401 impl<D> Argument for D
402 where
403 D: HostBufferCopyToDest,
404 {
405 type Repr = (D,);
406 }
407
408 impl<D> ToConfig<(D,), D> for D
409 where
410 D: HostBufferCopyToDest,
411 {
412 fn into_config(self) -> HostBufferCopyToConfig<D> {
413 HostBufferCopyToConfig::new(self)
414 }
415 }
416
417 impl<D, B> Argument for (D, B)
418 where
419 D: HostBufferCopyToDest,
420 B: Into<Vec<i64>>,
421 {
422 type Repr = (D, B);
423 }
424
425 impl<D, B> ToConfig<(D, B), D> for (D, B)
426 where
427 D: HostBufferCopyToDest,
428 B: Into<Vec<i64>>,
429 {
430 fn into_config(self) -> HostBufferCopyToConfig<D> {
431 HostBufferCopyToConfig::new(self.0).byte_strides(self.1.into())
432 }
433 }
434
435 impl<D> Argument for (D, MemoryLayout)
436 where
437 D: HostBufferCopyToDest,
438 {
439 type Repr = (D, MemoryLayout);
440 }
441
442 impl<D> ToConfig<(D, MemoryLayout), D> for (D, MemoryLayout)
443 where
444 D: HostBufferCopyToDest,
445 {
446 fn into_config(self) -> HostBufferCopyToConfig<D> {
447 HostBufferCopyToConfig::new(self.0).device_layout(self.1)
448 }
449 }
450
451 impl<'a, D> Argument for (D, &'a MemoryLayout)
452 where
453 D: HostBufferCopyToDest,
454 {
455 type Repr = (D, &'a MemoryLayout);
456 }
457
458 impl<'a, D> ToConfig<(D, &'a MemoryLayout), D> for (D, &'a MemoryLayout)
459 where
460 D: HostBufferCopyToDest,
461 {
462 fn into_config(self) -> HostBufferCopyToConfig<D> {
463 HostBufferCopyToConfig::new(self.0).device_layout(self.1.clone())
464 }
465 }
466
467 impl<D, B, M> Argument for (D, B, M)
468 where
469 D: HostBufferCopyToDest,
470 B: Into<Vec<i64>>,
471 M: Into<MemoryLayout>,
472 {
473 type Repr = (D, B, M);
474 }
475
476 impl<D, B, M> ToConfig<(D, B, M), D> for (D, B, M)
477 where
478 D: HostBufferCopyToDest,
479 B: Into<Vec<i64>>,
480 M: Into<MemoryLayout>,
481 {
482 fn into_config(self) -> HostBufferCopyToConfig<D> {
483 HostBufferCopyToConfig::new(self.0)
484 .byte_strides(self.1.into())
485 .device_layout(self.2.into())
486 }
487 }
488}
489
490pub trait IntoHostBufferCopyToConfig<D>
491where
492 D: HostBufferCopyToDest,
493{
494 fn into_copy_to_config(self) -> HostBufferCopyToConfig<D>;
495}
496
497impl<T, D> IntoHostBufferCopyToConfig<D> for T
498where
499 T: private::Argument + private::ToConfig<T::Repr, D>,
500 D: HostBufferCopyToDest,
501{
502 fn into_copy_to_config(self) -> HostBufferCopyToConfig<D> {
503 self.into_config()
504 }
505}
506
507#[derive(Debug)]
508pub struct TypedHostBufferBuilder;
509
510#[bon]
511impl TypedHostBufferBuilder {
512 #[builder(finish_fn = build)]
513 pub fn data<E>(
514 &self,
515 #[builder(start_fn, into)] data: Vec<E>,
516 #[builder(into)] dims: Option<Vec<i64>>,
517 #[builder] layout: Option<MemoryLayout>,
518 ) -> TypedHostBuffer<E::Type>
519 where
520 E: ElemType,
521 {
522 let dims = dims.unwrap_or_else(|| vec![data.len() as i64]);
523 let layout = layout
524 .unwrap_or_else(|| MemoryLayout::strides(utils::byte_strides(&dims, E::Type::SIZE)));
525 TypedHostBuffer {
526 data: Rc::new(data),
527 dims,
528 layout,
529 }
530 }
531
532 #[builder(finish_fn = build)]
533 pub fn bytes<T>(
534 &self,
535 #[builder(start_fn, into)] bytes: Vec<u8>,
536 #[builder(into)] dims: Option<Vec<i64>>,
537 #[builder] layout: Option<MemoryLayout>,
538 ) -> TypedHostBuffer<T>
539 where
540 T: Type,
541 {
542 let length = bytes.len() / T::SIZE;
543 let capacity = bytes.capacity() / T::SIZE;
544 let ptr = bytes.as_ptr() as *mut T::ElemType;
545 let data = unsafe { Vec::from_raw_parts(ptr, length, capacity) };
546 mem::forget(bytes);
547 let dims = dims.unwrap_or_else(|| vec![length as i64]);
548 assert!(dims.iter().product::<i64>() == length as i64);
549 let layout =
550 layout.unwrap_or_else(|| MemoryLayout::strides(utils::byte_strides(&dims, T::SIZE)));
551 TypedHostBuffer {
552 data: Rc::new(data),
553 dims,
554 layout,
555 }
556 }
557}
558
559#[derive(Debug)]
560pub struct HostBufferBuilder;
561
562#[bon]
563impl HostBufferBuilder {
564 #[builder(finish_fn = build)]
565 pub fn data<E>(
566 &self,
567 #[builder(start_fn, into)] data: Vec<E>,
568 #[builder(into)] dims: Option<Vec<i64>>,
569 #[builder] layout: Option<MemoryLayout>,
570 ) -> HostBuffer
571 where
572 E: ElemType,
573 HostBuffer: From<TypedHostBuffer<E::Type>>,
574 {
575 let buf = TypedHostBufferBuilder
576 .data::<E>(data)
577 .maybe_dims(dims)
578 .maybe_layout(layout)
579 .build();
580 HostBuffer::from(buf)
581 }
582
583 #[builder(finish_fn = build)]
584 pub fn bytes(
585 &self,
586 #[builder(start_fn)] bytes: Vec<u8>,
587 #[builder(start_fn)] ty: PrimitiveType,
588 #[builder(into)] dims: Option<Vec<i64>>,
589 #[builder] layout: Option<MemoryLayout>,
590 ) -> Result<HostBuffer> {
591 match ty {
592 PrimitiveType::F32 => Ok(HostBuffer::F32(
593 TypedHostBufferBuilder
594 .bytes::<F32>(bytes)
595 .maybe_dims(dims)
596 .maybe_layout(layout)
597 .build(),
598 )),
599 PrimitiveType::F64 => Ok(HostBuffer::F64(
600 TypedHostBufferBuilder
601 .bytes::<F64>(bytes)
602 .maybe_dims(dims)
603 .maybe_layout(layout)
604 .build(),
605 )),
606 PrimitiveType::S8 => Ok(HostBuffer::I8(
607 TypedHostBufferBuilder
608 .bytes::<I8>(bytes)
609 .maybe_dims(dims)
610 .maybe_layout(layout)
611 .build(),
612 )),
613 PrimitiveType::S16 => Ok(HostBuffer::I16(
614 TypedHostBufferBuilder
615 .bytes::<I16>(bytes)
616 .maybe_dims(dims)
617 .maybe_layout(layout)
618 .build(),
619 )),
620 PrimitiveType::S32 => Ok(HostBuffer::I32(
621 TypedHostBufferBuilder
622 .bytes::<I32>(bytes)
623 .maybe_dims(dims)
624 .maybe_layout(layout)
625 .build(),
626 )),
627 PrimitiveType::S64 => Ok(HostBuffer::I64(
628 TypedHostBufferBuilder
629 .bytes::<I64>(bytes)
630 .maybe_dims(dims)
631 .maybe_layout(layout)
632 .build(),
633 )),
634 PrimitiveType::U8 => Ok(HostBuffer::U8(
635 TypedHostBufferBuilder
636 .bytes::<U8>(bytes)
637 .maybe_dims(dims)
638 .maybe_layout(layout)
639 .build(),
640 )),
641 PrimitiveType::U16 => Ok(HostBuffer::U16(
642 TypedHostBufferBuilder
643 .bytes::<U16>(bytes)
644 .maybe_dims(dims)
645 .maybe_layout(layout)
646 .build(),
647 )),
648 PrimitiveType::U32 => Ok(HostBuffer::U32(
649 TypedHostBufferBuilder
650 .bytes::<U32>(bytes)
651 .maybe_dims(dims)
652 .maybe_layout(layout)
653 .build(),
654 )),
655 PrimitiveType::U64 => Ok(HostBuffer::U64(
656 TypedHostBufferBuilder
657 .bytes::<U64>(bytes)
658 .maybe_dims(dims)
659 .maybe_layout(layout)
660 .build(),
661 )),
662 _ => Err(Error::NotSupportedType(ty)),
663 }
664 }
665}