import sys
import traceback
from ctypes import *

# numpy and mpi4py are optional; each flag records (1/0) whether the
# corresponding import succeeded.  Only a missing/unimportable package is
# treated as "absent" -- other errors (and Ctrl-C) are not swallowed.
try:
    import numpy as np
    numpyflag = 1
except ImportError:
    numpyflag = 0

try:
    from mpi4py import MPI
    mpi4pyflag = 1
except ImportError:
    mpi4pyflag = 0
class CSlib:
    """Python wrapper, via ctypes, around the CSlib C interface.

    Field-type codes used by the pack/unpack methods map to ctypes element
    types as 1 = c_int, 2 = c_longlong, 3 = c_float, 4 = c_double.
    Written to be valid under both Python 2 and Python 3.
    """

    # ctypes element type for each CSlib field-type code.
    _CTYPES = {1: c_int, 2: c_longlong, 3: c_float, 4: c_double}

    def __init__(self, csflag, mode, ptr, comm):
        """Load the shared library and open a CSlib handle in self.cs.

        csflag -- integer flag passed through to cslib_open
        mode   -- mode string passed to cslib_open; "mpi/one" is special-cased
                  when comm is given
        ptr    -- mode-specific argument; with comm and mode "mpi/one" it is
                  an mpi4py object whose handle address is passed, otherwise
                  it is passed through (text is encoded to bytes first)
        comm   -- mpi4py communicator, or a false value to load the non-MPI
                  build of the library

        Raises OSError if the shared library cannot be loaded.  Exits the
        process with status 1 if comm is given but mpi4py is unavailable.
        """
        try:
            if comm:
                self.lib = CDLL("libcsmpi.so", RTLD_GLOBAL)
            else:
                self.lib = CDLL("libcsnompi.so", RTLD_GLOBAL)
        except OSError:
            traceback.print_exc()
            raise OSError("Could not load CSlib dynamic library")
        self._declare_prototypes()

        self.cs = c_void_p()
        # c_char_p needs bytes on Python 3; a no-op for Python 2 str.
        cmode = self._to_bytes(mode)
        if not comm:
            self.lib.cslib_open(csflag, cmode, self._to_bytes(ptr), None,
                                byref(self.cs))
        elif not mpi4pyflag:
            print("Cannot pass MPI communicator to CSlib w/out mpi4py package")
            sys.exit(1)
        else:
            comm_ptr = c_void_p(MPI._addressof(comm))
            if cmode == b"mpi/one":
                ptrcopy = c_void_p(MPI._addressof(ptr))
            else:
                ptrcopy = self._to_bytes(ptr)
            self.lib.cslib_open(csflag, cmode, ptrcopy, comm_ptr,
                                byref(self.cs))

    def _declare_prototypes(self):
        """Set argtypes/restype on every C entry point this class calls."""
        pint = POINTER(c_int)
        ppint = POINTER(pint)
        prototypes = (
            ("cslib_open",
             [c_int, c_char_p, c_void_p, c_void_p, POINTER(c_void_p)], None),
            ("cslib_close", [c_void_p], None),
            ("cslib_send", [c_void_p, c_int, c_int], None),
            ("cslib_pack_int", [c_void_p, c_int, c_int], None),
            ("cslib_pack_int64", [c_void_p, c_int, c_longlong], None),
            ("cslib_pack_float", [c_void_p, c_int, c_float], None),
            ("cslib_pack_double", [c_void_p, c_int, c_double], None),
            ("cslib_pack_string", [c_void_p, c_int, c_char_p], None),
            ("cslib_pack", [c_void_p, c_int, c_int, c_int, c_void_p], None),
            ("cslib_pack_parallel",
             [c_void_p, c_int, c_int, c_int, pint, c_int, c_void_p], None),
            ("cslib_recv", [c_void_p, pint, ppint, ppint, ppint], c_int),
            ("cslib_unpack_int", [c_void_p, c_int], c_int),
            ("cslib_unpack_int64", [c_void_p, c_int], c_longlong),
            ("cslib_unpack_float", [c_void_p, c_int], c_float),
            ("cslib_unpack_double", [c_void_p, c_int], c_double),
            ("cslib_unpack_string", [c_void_p, c_int], c_char_p),
            # restype is re-set per call in unpack() to match the field type
            ("cslib_unpack", [c_void_p, c_int], c_void_p),
            ("cslib_unpack_data", [c_void_p, c_int, c_void_p], None),
            ("cslib_unpack_parallel",
             [c_void_p, c_int, c_int, pint, c_int, c_void_p], None),
            ("cslib_extract", [c_void_p, c_int], c_int),
        )
        for name, argtypes, restype in prototypes:
            func = getattr(self.lib, name)
            func.argtypes = argtypes
            func.restype = restype

    @staticmethod
    def _to_bytes(value):
        """Encode text to UTF-8 bytes for char* arguments; pass others through.

        Python 2 str and Python 3 bytes are returned unchanged, as is any
        object without an encode() method (None, ints, ctypes objects, ...).
        """
        if isinstance(value, bytes) or not hasattr(value, "encode"):
            return value
        return value.encode("utf-8")

    def __del__(self):
        # Finalizer: release the handle unless close() already did.  close()
        # tolerates an __init__ that failed before self.lib/self.cs existed,
        # so this never raises from garbage collection.
        self.close()

    def close(self):
        """Close the CSlib handle and drop the library reference.

        Safe to call more than once; calls after the first do nothing.
        """
        lib = getattr(self, "lib", None)
        cs = getattr(self, "cs", None)
        if lib is not None and cs:
            lib.cslib_close(cs)
        self.cs = None
        self.lib = None

    def send(self, msgID, nfield):
        """Call cslib_send for message msgID with nfield fields."""
        self.nfield = nfield
        self.lib.cslib_send(self.cs, msgID, nfield)

    def pack_int(self, id, value):
        """Pack a single 32-bit int as field id."""
        self.lib.cslib_pack_int(self.cs, id, value)

    def pack_int64(self, id, value):
        """Pack a single 64-bit int as field id."""
        self.lib.cslib_pack_int64(self.cs, id, value)

    def pack_float(self, id, value):
        """Pack a single float as field id."""
        self.lib.cslib_pack_float(self.cs, id, value)

    def pack_double(self, id, value):
        """Pack a single double as field id."""
        self.lib.cslib_pack_double(self.cs, id, value)

    def pack_string(self, id, value):
        """Pack a string as field id (text is encoded to UTF-8 bytes)."""
        self.lib.cslib_pack_string(self.cs, id, self._to_bytes(value))

    def pack(self, id, ftype, flen, data):
        """Pack flen values of field-type ftype from data as field id."""
        cdata = self.data_convert(ftype, flen, data)
        self.lib.cslib_pack(self.cs, id, ftype, flen, cdata)

    def pack_parallel(self, id, ftype, nlocal, ids, nper, data):
        """Pack nlocal*nper values of type ftype, keyed by nlocal int ids."""
        cids = self.data_convert(1, nlocal, ids)
        cdata = self.data_convert(ftype, nper * nlocal, data)
        self.lib.cslib_pack_parallel(self.cs, id, ftype, nlocal, cids, nper,
                                     cdata)

    def data_convert(self, ftype, flen, data):
        """Return a ctypes object that passes data to C as field-type ftype.

        ctypes objects (simple values, arrays, pointers) are returned as-is.
        numpy arrays are converted (without copying when they already have
        the right dtype and are C-contiguous) and returned as a typed
        pointer.  Any other sequence is copied into a new ctypes array of
        length flen.

        Raises ValueError if ftype is not one of 1, 2, 3, 4.
        """
        ctype = self._field_ctype(ftype)
        if self._is_ctypes(data):
            return data
        if numpyflag and isinstance(data, np.ndarray):
            # A wrong dtype or a strided view would otherwise be read as raw
            # memory of the wrong element size by the C side.
            arr = np.ascontiguousarray(data, dtype=ctype)
            cdata = arr.ctypes.data_as(POINTER(ctype))
            # Keep a possible temporary copy alive as long as the pointer.
            cdata._cslib_keepalive = arr
            return cdata
        return (flen * ctype)(*data)

    @staticmethod
    def _is_ctypes(data):
        """Return True if data is a ctypes simple value, array or pointer."""
        import ctypes
        return isinstance(data, (ctypes._SimpleCData, ctypes.Array,
                                 ctypes._Pointer))

    @classmethod
    def _field_ctype(cls, ftype):
        """Map a CSlib field-type code to its ctypes type or raise ValueError."""
        try:
            return cls._CTYPES[ftype]
        except KeyError:
            raise ValueError("Unknown CSlib field type %r" % (ftype,))

    @staticmethod
    def _as_result(cdata, n, tflag):
        """Return n elements of cdata in the form selected by tflag.

        tflag 1 -> Python list copy; 2 -> numpy array built with
        np.ctypeslib.as_array over cdata; 3 -> cdata itself.
        Raises ValueError for any other tflag.
        """
        if tflag == 1:
            return cdata[:n]
        if tflag == 2:
            if numpyflag == 0:
                print("Cannot return Numpy array w/out numpy package")
                sys.exit(1)
            return np.ctypeslib.as_array(cdata, shape=(n,))
        if tflag == 3:
            return cdata
        raise ValueError("tflag must be 1, 2 or 3, got %r" % (tflag,))

    def recv(self):
        """Receive a message and cache its field descriptors on self.

        Returns (msgID, nfield, fieldID, fieldtype, fieldlen), where the last
        three are Python lists of length nfield.
        """
        nfield = c_int()
        fieldID = POINTER(c_int)()
        fieldtype = POINTER(c_int)()
        fieldlen = POINTER(c_int)()
        msgID = self.lib.cslib_recv(self.cs, byref(nfield), byref(fieldID),
                                    byref(fieldtype), byref(fieldlen))
        count = nfield.value
        self.nfield = count
        self.fieldID = fieldID[:count]
        self.fieldtype = fieldtype[:count]
        self.fieldlen = fieldlen[:count]
        return msgID, self.nfield, self.fieldID, self.fieldtype, self.fieldlen

    def unpack_int(self, id):
        """Return field id of the last received message as an int."""
        return self.lib.cslib_unpack_int(self.cs, id)

    def unpack_int64(self, id):
        """Return field id of the last received message as a 64-bit int."""
        return self.lib.cslib_unpack_int64(self.cs, id)

    def unpack_float(self, id):
        """Return field id of the last received message as a float."""
        return self.lib.cslib_unpack_float(self.cs, id)

    def unpack_double(self, id):
        """Return field id of the last received message as a double."""
        return self.lib.cslib_unpack_double(self.cs, id)

    def unpack_string(self, id):
        """Return string field id (a bytes object under Python 3)."""
        return self.lib.cslib_unpack_string(self.cs, id)

    def unpack(self, id, tflag=3):
        """Return array field id via the pointer cslib_unpack returns.

        tflag selects the result form (see _as_result).  Requires a prior
        recv(); raises ValueError if id was not in that message or has an
        unknown field type.
        """
        index = self.fieldID.index(id)
        ctype = self._field_ctype(self.fieldtype[index])
        self.lib.cslib_unpack.restype = POINTER(ctype)
        cdata = self.lib.cslib_unpack(self.cs, id)
        return self._as_result(cdata, self.fieldlen[index], tflag)

    def unpack_data(self, id, tflag=3):
        """Copy array field id into a new ctypes array and return it.

        The buffer is allocated here with the field's length and type and
        filled by cslib_unpack_data; tflag selects the result form (see
        _as_result).  Requires a prior recv().
        """
        index = self.fieldID.index(id)
        count = self.fieldlen[index]
        cdata = (count * self._field_ctype(self.fieldtype[index]))()
        self.lib.cslib_unpack_data(self.cs, id, cdata)
        return self._as_result(cdata, count, tflag)

    def unpack_parallel(self, id, nlocal, ids, nper, tflag=3):
        """Unpack nlocal*nper values of field id for the given nlocal ids.

        tflag selects the result form (see _as_result).  Requires a prior
        recv().
        """
        cids = self.data_convert(1, nlocal, ids)
        index = self.fieldID.index(id)
        count = nlocal * nper
        cdata = (count * self._field_ctype(self.fieldtype[index]))()
        self.lib.cslib_unpack_parallel(self.cs, id, nlocal, cids, nper, cdata)
        return self._as_result(cdata, count, tflag)

    def extract(self, flag):
        """Return the integer cslib_extract reports for flag."""
        return self.lib.cslib_extract(self.cs, flag)