use super::{
handle_status, FromPrimitive, Literal, NativeType, PrimitiveType, Shape, XlaComputation, XlaOp,
};
use crate::{c_lib, Error, Result};
use std::rc::Rc;
/// Owning wrapper around the raw `c_lib::xla_builder` handle.
///
/// The handle is released with `c_lib::xla_builder_free` when this value is dropped.
/// In this file it is only ever created inside the `Rc` held by [`XlaBuilder`], so the free
/// happens once, when the last `XlaBuilder` clone goes away.
pub(super) struct XlaBuilderInternal(c_lib::xla_builder);
/// Handle to a native XLA computation builder.
///
/// Cloning is cheap. Every clone shares the same underlying native builder through an `Rc`,
/// and the native builder is freed only when the last clone is dropped. Because it holds an
/// `Rc`, this type is neither `Send` nor `Sync`.
#[derive(Clone)]
pub struct XlaBuilder(Rc<XlaBuilderInternal>);
impl XlaBuilder {
    /// Creates a new native builder named `name`.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains an interior NUL byte.
    pub fn new(name: &str) -> XlaBuilder {
        let name =
            std::ffi::CString::new(name).expect("builder name must not contain interior NUL bytes");
        // SAFETY: `name` is a valid NUL-terminated C string that outlives the call.
        let xla_builder = unsafe { c_lib::xla_builder_create(name.as_ptr()) };
        XlaBuilder(Rc::new(XlaBuilderInternal(xla_builder)))
    }

    /// Raw native handle shared by every clone of this builder.
    fn ptr(&self) -> c_lib::xla_builder {
        self.0 .0
    }

    /// Converts an error message into a `CString`, escaping interior NUL bytes as `\0`.
    ///
    /// The error-op constructors return an `XlaOp` rather than a `Result`, so they cannot
    /// report a malformed message. Escaping keeps them (and `wrap_error`, which forwards
    /// arbitrary error text) from panicking. Messages without NUL bytes are passed through
    /// unchanged.
    fn error_message(msg: &str) -> std::ffi::CString {
        let escaped = msg.replace('\0', "\\0");
        std::ffi::CString::new(escaped).expect("interior NUL bytes were escaped")
    }

    /// Builds an [`XlaComputation`] by passing `op` to the native `build` call
    /// (presumably as the computation's root).
    ///
    /// # Errors
    ///
    /// Returns the error carried by the status of the native call.
    pub fn build(&self, op: &XlaOp) -> Result<XlaComputation> {
        let mut result: c_lib::xla_computation = std::ptr::null_mut();
        // SAFETY: the builder handle is kept alive by `self`; `result` is a valid out-pointer.
        let status = unsafe { c_lib::build(self.ptr(), op.op, &mut result) };
        handle_status(status)?;
        Ok(XlaComputation(result))
    }

    /// Returns `Ok(())` unless the native `first_error` query reports an error.
    ///
    /// # Errors
    ///
    /// Returns the error carried by the status of the native call.
    pub fn first_error(&self) -> Result<()> {
        // SAFETY: the builder handle is kept alive by `self`.
        let status = unsafe { c_lib::first_error(self.ptr()) };
        handle_status(status)?;
        Ok(())
    }

    /// Returns `Ok(())` unless the native `get_current_status` query reports an error.
    ///
    /// # Errors
    ///
    /// Returns the error carried by the status of the native call.
    pub fn get_current_status(&self) -> Result<()> {
        // SAFETY: the builder handle is kept alive by `self`.
        let status = unsafe { c_lib::get_current_status(self.ptr()) };
        handle_status(status)?;
        Ok(())
    }

    /// Adds a constant op holding the value of `literal`.
    ///
    /// # Errors
    ///
    /// Fails if the builder reports an error status after the op is added (see [`Self::wrap`]).
    pub fn constant_literal(&self, literal: &Literal) -> Result<XlaOp> {
        // SAFETY: builder handle kept alive by `self`; `literal` is borrowed for the call.
        let op = unsafe { c_lib::constant_literal(self.ptr(), literal.0) };
        self.wrap(op)
    }

    /// Adds a scalar (rank-0) constant op with value `f`.
    ///
    /// # Errors
    ///
    /// Fails if the builder reports an error status after the op is added.
    pub fn constant_r0<T: NativeType>(&self, f: T) -> Result<XlaOp> {
        // SAFETY: the builder handle is kept alive by `self`.
        let op = unsafe { T::constant_r0(self.ptr(), f) };
        self.wrap(op)
    }

    /// Shorthand for [`Self::constant_r0`].
    pub fn c0<T: NativeType>(&self, f: T) -> Result<XlaOp> {
        self.constant_r0(f)
    }

    /// Wraps a raw op handle into an [`XlaOp`] bound to this builder.
    ///
    /// The builder status is checked first, so an error recorded while creating `op`
    /// surfaces here instead of later.
    ///
    /// # Errors
    ///
    /// Returns the builder's current error status, if any.
    pub fn wrap(&self, op: c_lib::xla_op) -> Result<XlaOp> {
        self.get_current_status()?;
        Ok(XlaOp { op, builder: self.clone() })
    }

    /// Adds parameter number `parameter_number` with element type `ty`, dimensions `dims`
    /// and name `name`.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains an interior NUL byte.
    ///
    /// # Errors
    ///
    /// Fails if the builder reports an error status after the op is added.
    pub fn parameter(
        &self,
        parameter_number: i64,
        ty: super::ElementType,
        dims: &[i64],
        name: &str,
    ) -> Result<XlaOp> {
        let name = std::ffi::CString::new(name)
            .expect("parameter name must not contain interior NUL bytes");
        // SAFETY: builder handle kept alive by `self`; `dims` and `name` outlive the call.
        // NOTE(review): `dims.len() as i32` is unchecked; the rank is assumed to fit in i32.
        let op = unsafe {
            c_lib::parameter(
                self.ptr(),
                parameter_number,
                ty.primitive_type() as i32,
                dims.len() as i32,
                dims.as_ptr(),
                name.as_ptr(),
            )
        };
        self.wrap(op)
    }

    /// Adds an infeed op with primitive type `ty`, dimensions `dims` and configuration
    /// string `config`.
    ///
    /// # Panics
    ///
    /// Panics if `config` contains an interior NUL byte.
    ///
    /// # Errors
    ///
    /// Fails if the builder reports an error status after the op is added.
    pub fn infeed(&self, ty: PrimitiveType, dims: &[i64], config: &str) -> Result<XlaOp> {
        let config = std::ffi::CString::new(config)
            .expect("infeed config must not contain interior NUL bytes");
        // SAFETY: builder handle kept alive by `self`; `dims` and `config` outlive the call.
        // NOTE(review): `dims.len() as i32` is unchecked; the rank is assumed to fit in i32.
        let op = unsafe {
            c_lib::infeed(self.ptr(), ty as i32, dims.len() as i32, dims.as_ptr(), config.as_ptr())
        };
        self.wrap(op)
    }

    /// Adds parameter number `parameter_number` described by a full [`Shape`].
    ///
    /// # Panics
    ///
    /// Panics if `name` contains an interior NUL byte.
    ///
    /// # Errors
    ///
    /// Fails if `shape` cannot be converted to a native shape, or if the builder reports an
    /// error status after the op is added.
    pub fn parameter_s(&self, parameter_number: i64, shape: &Shape, name: &str) -> Result<XlaOp> {
        let c_shape = shape.c_shape()?;
        let name = std::ffi::CString::new(name)
            .expect("parameter name must not contain interior NUL bytes");
        // SAFETY: builder handle kept alive by `self`; `c_shape` and `name` outlive the call.
        let op = unsafe {
            c_lib::parameter_s(self.ptr(), parameter_number, c_shape.as_ptr(), name.as_ptr())
        };
        // Release the temporary native shape right after the call that borrowed it.
        drop(c_shape);
        self.wrap(op)
    }

    /// Adds a rank-1 constant built from the single value `f` and length `len`
    /// (presumably `len` copies of `f`).
    ///
    /// # Errors
    ///
    /// Fails if the builder reports an error status after the op is added.
    pub fn constant_r1c<T: NativeType>(&self, f: T, len: usize) -> Result<XlaOp> {
        // SAFETY: the builder handle is kept alive by `self`.
        let op = unsafe { T::constant_r1c(self.ptr(), f, len) };
        self.wrap(op)
    }

    /// Adds a rank-1 constant holding the elements of `f`.
    ///
    /// # Errors
    ///
    /// Fails if the builder reports an error status after the op is added.
    pub fn constant_r1<T: NativeType>(&self, f: &[T]) -> Result<XlaOp> {
        // SAFETY: builder handle kept alive by `self`; pointer/length describe the live slice.
        let op = unsafe { T::constant_r1(self.ptr(), f.as_ptr(), f.len()) };
        self.wrap(op)
    }

    /// Shorthand for [`Self::constant_r1`].
    pub fn c1<T: NativeType>(&self, f: &[T]) -> Result<XlaOp> {
        self.constant_r1(f)
    }

    /// Adds the op produced by the native `op_zero` for element type `ty`.
    ///
    /// # Errors
    ///
    /// Fails if the builder reports an error status after the op is added.
    pub fn zero(&self, ty: super::ElementType) -> Result<XlaOp> {
        // SAFETY: the builder handle is kept alive by `self`.
        let op = unsafe { c_lib::op_zero(self.ptr(), ty.primitive_type() as i32) };
        self.wrap(op)
    }

    /// Adds the op produced by the native `op_one` for element type `ty`.
    ///
    /// # Errors
    ///
    /// Fails if the builder reports an error status after the op is added.
    pub fn one(&self, ty: super::ElementType) -> Result<XlaOp> {
        // SAFETY: the builder handle is kept alive by `self`.
        let op = unsafe { c_lib::op_one(self.ptr(), ty.primitive_type() as i32) };
        self.wrap(op)
    }

    /// Adds the op produced by the native `op_min_value` for element type `ty`.
    ///
    /// # Errors
    ///
    /// Fails if the builder reports an error status after the op is added.
    pub fn min_value(&self, ty: super::ElementType) -> Result<XlaOp> {
        // SAFETY: the builder handle is kept alive by `self`.
        let op = unsafe { c_lib::op_min_value(self.ptr(), ty.primitive_type() as i32) };
        self.wrap(op)
    }

    /// Adds the op produced by the native `op_max_value` for element type `ty`.
    ///
    /// # Errors
    ///
    /// Fails if the builder reports an error status after the op is added.
    pub fn max_value(&self, ty: super::ElementType) -> Result<XlaOp> {
        // SAFETY: the builder handle is kept alive by `self`.
        let op = unsafe { c_lib::op_max_value(self.ptr(), ty.primitive_type() as i32) };
        self.wrap(op)
    }

    /// Adds an iota op of element type `ty` and dimensions `dims` along `iota_dimension`.
    ///
    /// # Errors
    ///
    /// Fails if the builder reports an error status after the op is added.
    pub fn iota(&self, ty: super::ElementType, dims: &[i64], iota_dimension: i64) -> Result<XlaOp> {
        // SAFETY: builder handle kept alive by `self`; pointer/length describe the live slice.
        let op = unsafe {
            c_lib::op_iota(
                self.ptr(),
                ty.primitive_type() as i32,
                dims.len(),
                dims.as_ptr(),
                iota_dimension,
            )
        };
        self.wrap(op)
    }

    /// Adds a one-dimensional iota op of element type `ty` and length `size`.
    ///
    /// # Errors
    ///
    /// Fails if the builder reports an error status after the op is added.
    pub fn iota1(&self, ty: super::ElementType, size: usize) -> Result<XlaOp> {
        // SAFETY: the builder handle is kept alive by `self`.
        let op = unsafe { c_lib::op_iota1(self.ptr(), ty.primitive_type() as i32, size) };
        self.wrap(op)
    }

    /// Creates an op from the native `op_internal_error` with message `msg`.
    ///
    /// The builder status is not checked, so an op is always returned. Interior NUL bytes in
    /// `msg` are escaped as `\0` instead of causing a panic.
    pub fn internal_error(&self, msg: &str) -> XlaOp {
        let msg = Self::error_message(msg);
        // SAFETY: builder handle kept alive by `self`; `msg` outlives the call.
        let op = unsafe { c_lib::op_internal_error(self.ptr(), msg.as_ptr()) };
        XlaOp { op, builder: self.clone() }
    }

    /// Creates an op from the native `op_unknown_error` with message `msg`.
    ///
    /// The builder status is not checked, so an op is always returned. Interior NUL bytes in
    /// `msg` are escaped as `\0` instead of causing a panic.
    pub fn unknown_error(&self, msg: &str) -> XlaOp {
        let msg = Self::error_message(msg);
        // SAFETY: builder handle kept alive by `self`; `msg` outlives the call.
        let op = unsafe { c_lib::op_unknown_error(self.ptr(), msg.as_ptr()) };
        XlaOp { op, builder: self.clone() }
    }

    /// Creates an op from the native `op_invalid_argument_error` with message `msg`.
    ///
    /// The builder status is not checked, so an op is always returned. Interior NUL bytes in
    /// `msg` are escaped as `\0` instead of causing a panic.
    pub fn invalid_argument_error(&self, msg: &str) -> XlaOp {
        let msg = Self::error_message(msg);
        // SAFETY: builder handle kept alive by `self`; `msg` outlives the call.
        let op = unsafe { c_lib::op_invalid_argument_error(self.ptr(), msg.as_ptr()) };
        XlaOp { op, builder: self.clone() }
    }

    /// Unwraps `op`, or turns its error into an [`Self::internal_error`] op carrying the
    /// error's `Display` text.
    pub fn wrap_error(&self, op: Result<XlaOp>) -> XlaOp {
        match op {
            Ok(op) => op,
            Err(err) => self.internal_error(&err.to_string()),
        }
    }

    /// Returns the [`Shape`] of `op` as reported by the native builder.
    ///
    /// # Errors
    ///
    /// Fails if the native query reports an error, or if the native shape cannot be
    /// converted.
    pub fn get_shape(&self, op: &XlaOp) -> Result<Shape> {
        let mut out: c_lib::shape = std::ptr::null_mut();
        // SAFETY: the builder handle is kept alive by `self`; `out` is a valid out-pointer.
        let status = unsafe { c_lib::get_shape(self.ptr(), op.op, &mut out) };
        handle_status(status)?;
        // Presumably `CShape` takes ownership of `out` and frees it on drop — confirm.
        let c_shape = super::shape::CShape::from_ptr(out);
        c_shape.shape()
    }

    /// Returns the dimension sizes of `op`. The returned vector has one entry per dimension.
    ///
    /// # Errors
    ///
    /// Fails if either native query (rank or dimensions) reports an error.
    pub fn get_dims(&self, op: &XlaOp) -> Result<Vec<usize>> {
        let rank = self.get_dimensions_size(op)?;
        let mut dims = vec![0; rank];
        // SAFETY: `dims` has exactly `rank` slots, the count the builder just reported for
        // this op, so the native call writes within bounds.
        let status = unsafe { c_lib::get_dimensions(self.ptr(), op.op, dims.as_mut_ptr()) };
        handle_status(status)?;
        Ok(dims)
    }

    /// Returns the element type of `op`.
    ///
    /// # Errors
    ///
    /// Fails if the native query reports an error. Returns
    /// `Error::UnexpectedElementType` if the reported code maps to no [`PrimitiveType`].
    pub fn get_primitive_type(&self, op: &XlaOp) -> Result<super::PrimitiveType> {
        let mut ty = 0i32;
        // SAFETY: the builder handle is kept alive by `self`; `ty` is a valid out-pointer.
        let status = unsafe { c_lib::get_element_type(self.ptr(), op.op, &mut ty) };
        handle_status(status)?;
        FromPrimitive::from_i32(ty).ok_or(Error::UnexpectedElementType(ty))
    }

    /// Returns the number of dimensions (rank) of `op`.
    ///
    /// # Errors
    ///
    /// Fails if the native query reports an error.
    pub fn get_dimensions_size(&self, op: &XlaOp) -> Result<usize> {
        let mut dsize = 0i32;
        // SAFETY: the builder handle is kept alive by `self`; `dsize` is a valid out-pointer.
        let status = unsafe { c_lib::get_dimensions_size(self.ptr(), op.op, &mut dsize) };
        handle_status(status)?;
        // NOTE(review): assumes a successful status implies a non-negative rank.
        Ok(dsize as usize)
    }

    /// Adds a tuple op whose elements are `args`, in order.
    ///
    /// # Errors
    ///
    /// Fails if the builder reports an error status after the op is added.
    pub fn tuple<B: std::borrow::Borrow<XlaOp>>(&self, args: &[B]) -> Result<XlaOp> {
        let args: Vec<_> = args.iter().map(|a| a.borrow().op).collect();
        // SAFETY: builder handle kept alive by `self`; `args` outlives the call and the
        // pointer/length pair describes it exactly.
        let op = unsafe { c_lib::op_tuple(self.ptr(), args.as_ptr(), args.len()) };
        self.wrap(op)
    }
}
impl Drop for XlaBuilderInternal {
    /// Hands the native builder handle back to the C library.
    ///
    /// In this file the value only lives inside the `Rc` owned by `XlaBuilder`, so this runs
    /// once, after the last clone of the builder has been dropped.
    fn drop(&mut self) {
        let raw_builder = self.0;
        // SAFETY: `raw_builder` is the handle stored at construction (from
        // `xla_builder_create` in `XlaBuilder::new`; confirm no other `super` module builds
        // this type). It is never used again after this call.
        unsafe { c_lib::xla_builder_free(raw_builder) }
    }
}