1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
use crate::{
detections::Detections,
error::Error,
image::IntoCowImage,
layers::{Layer, Layers},
utils,
};
use darknet_sys as sys;
use std::{
ffi::c_void,
os::raw::c_int,
path::Path,
ptr::{self, NonNull},
slice,
};
/// The network wrapper type for Darknet.
///
/// Holds a non-null pointer to a darknet `network` struct. The struct and the buffers it
/// owns are released when the wrapper is dropped.
pub struct Network {
    /// Pointer to the darknet network; guaranteed non-null by `NonNull`.
    net: NonNull<sys::network>,
}
impl Network {
    /// Build the network instance from a configuration file and an optional weights file.
    ///
    /// `clear` is forwarded to darknet's `load_network` as its `clear` flag.
    ///
    /// This will abort the program with an exit code of 1 if any of the following occur.
    /// - The config has no sections.
    /// - The first section of the config is not `[net]` or `[network]`.
    /// - `fopen` fails on `weights` (if provided).
    /// - The weights file is invalid.
    ///
    /// # Errors
    ///
    /// Returns an [Err] if `cfg` or `weights` (if provided) contain a null byte, or if
    /// `load_network` returns a null pointer.
    pub fn load<C, W>(cfg: C, weights: Option<W>, clear: bool) -> Result<Network, Error>
    where
        C: AsRef<Path>,
        W: AsRef<Path>,
    {
        let weights_cstr = weights
            .map(|path| utils::path_to_cstring_or_error(path.as_ref()))
            .transpose()?;
        let cfg_cstr = utils::path_to_cstring_or_error(cfg.as_ref())?;

        // `weights_cstr` and `cfg_cstr` are bound to locals that live until the end of this
        // function, so the raw pointers taken from them stay valid across the FFI call.
        let raw_weights = weights_cstr
            .as_ref()
            .map_or(ptr::null_mut(), |cstr| cstr.as_ptr() as *mut _);
        let raw_cfg = cfg_cstr.as_ptr() as *mut _;

        // SAFETY: `raw_cfg` points to a live NUL-terminated string; `raw_weights` is either
        // null or points to a live NUL-terminated string.
        let ptr = unsafe { sys::load_network(raw_cfg, raw_weights, c_int::from(clear)) };

        let net = NonNull::new(ptr).ok_or_else(|| Error::InternalError {
            reason: "failed to load model".into(),
        })?;
        Ok(Self { net })
    }

    /// Get network input width.
    pub fn input_width(&self) -> usize {
        // SAFETY: `self.net` is non-null and owned by `self`.
        unsafe { self.net.as_ref().w as usize }
    }

    /// Get network input height.
    pub fn input_height(&self) -> usize {
        // SAFETY: `self.net` is non-null and owned by `self`.
        unsafe { self.net.as_ref().h as usize }
    }

    /// Get network input channels.
    pub fn input_channels(&self) -> usize {
        // SAFETY: `self.net` is non-null and owned by `self`.
        unsafe { self.net.as_ref().c as usize }
    }

    /// Get network input shape tuple (channels, height, width).
    pub fn input_shape(&self) -> (usize, usize, usize) {
        (
            self.input_channels(),
            self.input_height(),
            self.input_width(),
        )
    }

    /// Get the number of layers.
    pub fn num_layers(&self) -> usize {
        // SAFETY: `self.net` is non-null and owned by `self`.
        unsafe { self.net.as_ref().n as usize }
    }

    /// Borrow the network's layer array as a slice of length [`Network::num_layers`].
    ///
    /// Returns an empty slice when the network reports zero layers or its layer pointer is
    /// null, because `slice::from_raw_parts` requires a non-null pointer even for length 0.
    fn raw_layers(&self) -> &[sys::layer] {
        let count = self.num_layers();
        // SAFETY: `self.net` is non-null and owned by `self`.
        let first = unsafe { self.net.as_ref().layers };
        if count == 0 || first.is_null() {
            &[]
        } else {
            // SAFETY: `first` is non-null and, per the network's own `n` field, points to
            // `count` consecutive layers that live as long as `self`.
            unsafe { slice::from_raw_parts(first, count) }
        }
    }

    /// Get network layers.
    pub fn layers(&self) -> Layers {
        Layers {
            layers: self.raw_layers(),
        }
    }

    /// Get layer by index, or `None` if `index >= self.num_layers()`.
    pub fn get_layer(&self, index: usize) -> Option<Layer> {
        self.raw_layers().get(index).map(|layer| Layer { layer })
    }

    /// Run inference on an image.
    ///
    /// - `thresh` / `hier_thres` are forwarded to darknet's `get_network_boxes`.
    /// - `nms`: when non-zero, non-maximum suppression is applied with this threshold, using
    ///   `do_nms_sort` or `diounms_sort` depending on the output layer's `nms_kind`.
    /// - `use_letter_box` selects `network_predict_image_letterbox` over
    ///   `network_predict_image` and is forwarded as the letterbox flag of
    ///   `get_network_boxes`.
    ///
    /// # Panics
    ///
    /// Panics if the network has no layers (there is no output layer to decode), or if
    /// `get_network_boxes` returns a null detection array.
    pub fn predict<'a, M>(
        &mut self,
        image: M,
        thresh: f32,
        hier_thres: f32,
        nms: f32,
        use_letter_box: bool,
    ) -> Detections
    where
        M: IntoCowImage<'a>,
    {
        let cow = image.into_cow_image();

        // Copy out the output-layer parameters up front so no borrow of `self` is held
        // across the FFI calls below, which mutate the network.
        let (classes, nms_kind, beta_nms) = {
            let output_layer = self
                .raw_layers()
                .last()
                .expect("network has no layers, so there is no output layer to decode");
            (
                output_layer.classes,
                output_layer.nms_kind,
                output_layer.beta_nms,
            )
        };

        let net = self.net.as_ptr();
        let mut nboxes: c_int = 0;

        // SAFETY: `net` is non-null and exclusively borrowed through `&mut self`;
        // `nboxes` is a valid out-pointer for the duration of the call.
        let dets_ptr = unsafe {
            // run prediction
            if use_letter_box {
                sys::network_predict_image_letterbox(net, cow.image);
            } else {
                sys::network_predict_image(net, cow.image);
            }
            sys::get_network_boxes(
                net,
                cow.width() as c_int,
                cow.height() as c_int,
                thresh,
                hier_thres,
                ptr::null_mut(),
                1,
                &mut nboxes,
                c_int::from(use_letter_box),
            )
        };
        let dets =
            NonNull::new(dets_ptr).expect("get_network_boxes returned a null detection array");

        // NMS sort
        if nms != 0.0 {
            // SAFETY: `dets` points to the `nboxes` detections just returned by darknet.
            unsafe {
                if nms_kind == sys::NMS_KIND_DEFAULT_NMS {
                    sys::do_nms_sort(dets.as_ptr(), nboxes, classes, nms);
                } else {
                    sys::diounms_sort(dets.as_ptr(), nboxes, classes, nms, nms_kind, beta_nms);
                }
            }
        }

        Detections {
            detections: dets,
            n_detections: nboxes as usize,
        }
    }
}
impl Drop for Network {
    /// Releases the darknet network: first the buffers it owns, then the struct itself.
    fn drop(&mut self) {
        let raw_net = self.net.as_ptr();
        // SAFETY: `raw_net` is the non-null pointer this wrapper exclusively owns, and it is
        // released exactly once, here.
        unsafe {
            // Frees the resources referenced by the network struct.
            sys::free_network(*raw_net);
            // The `network` struct itself was allocated with calloc, so it must be released
            // separately with libc's `free`.
            libc::free(raw_net.cast::<c_void>());
        }
    }
}
// SAFETY: `Network` is the sole owner of its `network` allocation (it is freed in `Drop`),
// so moving it to another thread transfers that ownership wholesale. `Sync` is not
// implemented here; inference goes through `&mut self`.
// NOTE(review): this assumes darknet keeps no state tied to the thread that loaded the
// network (e.g. a GPU context in GPU-enabled builds) — confirm.
unsafe impl Send for Network {}