use std::ffi::{CStr, CString};
use std::fs::File;
use std::io::BufWriter;
use std::path::Path;
/// Base URL of the repository that hosts the pre-trained model files.
/// `download_model` appends `/<model filename>` to this value to build the download URL.
const MODEL_BASE_URL: &str =
"https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3131";
/// Error returned by every fallible operation in this module.
///
/// It carries only a human-readable message; I/O errors are converted by
/// taking their `Display` text.
#[derive(Debug, Clone)]
pub struct UdpipeError {
    /// Description of what went wrong, without the `UDPipe error: ` prefix
    /// that `Display` adds.
    pub message: String,
}

impl UdpipeError {
    /// Creates an error from anything that converts into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for UdpipeError {
    /// Renders the error as `UDPipe error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("UDPipe error: {}", self.message))
    }
}

impl std::error::Error for UdpipeError {}

impl From<std::io::Error> for UdpipeError {
    /// Keeps the I/O error's display text as the message; the original
    /// error value itself is not retained.
    fn from(err: std::io::Error) -> Self {
        Self::new(err.to_string())
    }
}
/// One word of a parsed text, copied out of the parser's result into owned
/// Rust strings.
///
/// The field names mirror the columns UDPipe reports (presumably the CoNLL-U
/// columns — confirm against the C wrapper).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Word {
    /// Word form as reported by the parser.
    pub form: String,
    /// Lemma as reported by the parser.
    pub lemma: String,
    /// Universal part-of-speech tag; compared against values such as `VERB`,
    /// `NOUN`, `ADJ`, and `PUNCT` by the helpers below.
    pub upostag: String,
    /// Second tag column reported by the parser (presumably the
    /// language-specific POS tag — confirm).
    pub xpostag: String,
    /// Morphological features as `|`-separated `Key=Value` pairs.
    pub feats: String,
    /// Dependency relation label; `root` marks the sentence root.
    pub deprel: String,
    /// Miscellaneous `|`-separated attributes, e.g. `SpaceAfter=No`.
    pub misc: String,
    /// Word index as reported by the parser.
    pub id: i32,
    /// Index of this word's head as reported by the parser.
    pub head: i32,
    /// Index of the sentence this word belongs to, as reported by the parser.
    pub sentence_id: i32,
}

impl Word {
    /// Returns `true` when the feature `key` is present with exactly `value`.
    #[must_use]
    pub fn has_feature(&self, key: &str, value: &str) -> bool {
        self.get_feature(key) == Some(value)
    }

    /// Looks up the value of the morphological feature `key`.
    ///
    /// The key must match a whole feature name: the text after the key has to
    /// start with `=`, so `Verb` does not match `VerbForm=Fin`.
    #[must_use]
    pub fn get_feature(&self, key: &str) -> Option<&str> {
        self.feats
            .split('|')
            .find_map(|pair| pair.strip_prefix(key)?.strip_prefix('='))
    }

    /// Returns `true` for main verbs and auxiliaries (`VERB` or `AUX`).
    #[must_use]
    pub fn is_verb(&self) -> bool {
        matches!(self.upostag.as_str(), "VERB" | "AUX")
    }

    /// Returns `true` for common and proper nouns (`NOUN` or `PROPN`).
    #[must_use]
    pub fn is_noun(&self) -> bool {
        matches!(self.upostag.as_str(), "NOUN" | "PROPN")
    }

    /// Returns `true` when the word is tagged `ADJ`.
    #[must_use]
    pub fn is_adjective(&self) -> bool {
        self.upostag == "ADJ"
    }

    /// Returns `true` when the word is tagged `PUNCT`.
    #[must_use]
    pub fn is_punct(&self) -> bool {
        self.upostag == "PUNCT"
    }

    /// Returns `true` when the dependency relation is `root`.
    #[must_use]
    pub fn is_root(&self) -> bool {
        self.deprel == "root"
    }

    /// Returns `false` only when `misc` contains the attribute `SpaceAfter=No`.
    ///
    /// Attributes are compared as whole `|`-separated items, so an unrelated
    /// attribute that merely contains that text (e.g. `SpaceAfter=Nope`) does
    /// not suppress the space.
    #[must_use]
    pub fn has_space_after(&self) -> bool {
        !self.misc.split('|').any(|attr| attr == "SpaceAfter=No")
    }
}
/// Raw declarations for the C wrapper around UDPipe.
///
/// Nothing here is safe to call directly; the safe entry points are the
/// methods on `Model`.
mod ffi {
use std::os::raw::c_char;
/// Opaque model handle. The zero-sized private field means it is only ever
/// used behind a raw pointer handed out by the C side.
#[repr(C)]
pub struct UdpipeModel {
_private: [u8; 0],
}
/// Opaque parse-result handle, used only behind a raw pointer.
#[repr(C)]
pub struct UdpipeParseResult {
_private: [u8; 0],
}
/// One word as returned by value from `udpipe_result_get_word`.
///
/// The string fields are raw C string pointers. `Model::parse` copies them
/// into owned `String`s before calling `udpipe_result_free`. The fields can
/// be null: this file's tests observe null `form`/`lemma`/`upostag` when the
/// result pointer passed in is null.
#[repr(C)]
pub struct UdpipeWord {
pub form: *const c_char,
pub lemma: *const c_char,
pub upostag: *const c_char,
pub xpostag: *const c_char,
pub feats: *const c_char,
pub deprel: *const c_char,
pub misc: *const c_char,
pub id: i32,
pub head: i32,
pub sentence_id: i32,
}
unsafe extern "C" {
// Loads a model from a NUL-terminated path. Callers in this file treat a
// null return as failure and read the message via `udpipe_get_error`.
pub fn udpipe_model_load(model_path: *const c_char) -> *mut UdpipeModel;
// Loads a model from `len` bytes starting at `data`; null return is
// treated as failure by the caller.
pub fn udpipe_model_load_from_memory(data: *const u8, len: usize) -> *mut UdpipeModel;
// Releases a model handle; called from `Drop for Model` for non-null handles.
pub fn udpipe_model_free(model: *mut UdpipeModel);
// Parses NUL-terminated text; a null return is treated as failure by the
// caller. The tests expect a null `model` to be rejected with an
// "Invalid arguments" error rather than crashing.
pub fn udpipe_parse(model: *mut UdpipeModel, text: *const c_char)
-> *mut UdpipeParseResult;
// Releases a parse result obtained from `udpipe_parse`.
pub fn udpipe_result_free(result: *mut UdpipeParseResult);
// Returns the message for the most recent failure as a C string.
// NOTE(review): lifetime and thread scope of the returned pointer are not
// visible here — confirm against the C wrapper.
pub fn udpipe_get_error() -> *const c_char;
// Number of words in a result; the tests expect 0 for a null result.
pub fn udpipe_result_word_count(result: *mut UdpipeParseResult) -> i32;
// Returns the word at `index` by value; the tests expect null string
// fields for a null result or a negative index.
pub fn udpipe_result_get_word(result: *mut UdpipeParseResult, index: i32) -> UdpipeWord;
}
}
/// Fetches the error message last recorded by the C wrapper as an owned `String`.
///
/// Invalid UTF-8 in the message is replaced rather than rejected.
///
/// # Panics
///
/// Panics if the wrapper returns a null error pointer.
fn get_ffi_error() -> String {
    // SAFETY: the call takes no arguments.
    let raw_message = unsafe { ffi::udpipe_get_error() };
    assert!(!raw_message.is_null(), "UDPipe returned null error pointer");
    // SAFETY: the pointer was just checked to be non-null; this assumes the
    // wrapper hands back a NUL-terminated string that stays valid for the
    // duration of this call — TODO confirm against the C wrapper.
    let message = unsafe { CStr::from_ptr(raw_message) };
    String::from_utf8_lossy(message.to_bytes()).into_owned()
}
/// Owning handle to a UDPipe model loaded through the C wrapper.
///
/// The handle is released with `udpipe_model_free` when the value is dropped.
pub struct Model {
// Raw pointer returned by the C loader. The constructors in this file never
// store a null pointer; only the unit tests build a `Model` with a null one.
inner: *mut ffi::UdpipeModel,
}
impl std::fmt::Debug for Model {
    /// Prints `Model { inner: <bool> }`, where the boolean says whether the
    /// underlying handle is non-null; the raw address is never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_handle = !self.inner.is_null();
        let mut builder = f.debug_struct("Model");
        builder.field("inner", &has_handle);
        builder.finish()
    }
}
// Allows a `Model` to be moved to another thread (it is not declared `Sync`).
// NOTE(review): nothing in this file shows why this is sound; it presumably
// relies on the C wrapper keeping no thread-affine state per model handle —
// confirm against the wrapper before relying on it.
unsafe impl Send for Model {}
impl Model {
    /// Loads a model from the file at `path`.
    ///
    /// On Unix the path is passed to the C side byte-for-byte, so paths that
    /// are not valid UTF-8 still refer to the right file. On other platforms
    /// the path is converted lossily to UTF-8.
    ///
    /// # Errors
    ///
    /// Returns an error if the path contains a NUL byte, or with the C
    /// wrapper's own message if it fails to load the model.
    pub fn load(path: impl AsRef<Path>) -> Result<Self, UdpipeError> {
        let path = path.as_ref();
        // A lossy conversion would replace invalid UTF-8 with U+FFFD and make
        // the C side open a different path, so use the raw bytes where the
        // platform exposes them.
        #[cfg(unix)]
        let path_bytes: Vec<u8> =
            std::os::unix::ffi::OsStrExt::as_bytes(path.as_os_str()).to_vec();
        #[cfg(not(unix))]
        let path_bytes: Vec<u8> = path.to_string_lossy().into_owned().into_bytes();

        let c_path = CString::new(path_bytes).map_err(|_| UdpipeError {
            message: "Invalid path (contains null byte)".to_owned(),
        })?;
        // SAFETY: `c_path` is a valid NUL-terminated string that outlives the call.
        let model = unsafe { ffi::udpipe_model_load(c_path.as_ptr()) };
        if model.is_null() {
            return Err(UdpipeError {
                message: get_ffi_error(),
            });
        }
        Ok(Self { inner: model })
    }

    /// Loads a model from an in-memory copy of a model file.
    ///
    /// # Errors
    ///
    /// Returns the C wrapper's error message if it rejects the data.
    pub fn load_from_memory(data: &[u8]) -> Result<Self, UdpipeError> {
        // SAFETY: pointer and length describe the live `data` slice.
        // NOTE(review): assumes the wrapper does not keep the pointer after
        // returning — confirm against the C wrapper.
        let model = unsafe { ffi::udpipe_model_load_from_memory(data.as_ptr(), data.len()) };
        if model.is_null() {
            return Err(UdpipeError {
                message: get_ffi_error(),
            });
        }
        Ok(Self { inner: model })
    }

    /// Parses `text` and returns every word of the result as an owned [`Word`].
    ///
    /// # Errors
    ///
    /// Returns an error if `text` contains a NUL byte, or with the C
    /// wrapper's own message if parsing fails.
    pub fn parse(&self, text: &str) -> Result<Vec<Word>, UdpipeError> {
        let c_text = CString::new(text).map_err(|_| UdpipeError {
            message: "Invalid text (contains null byte)".to_owned(),
        })?;
        // SAFETY: `c_text` is a valid NUL-terminated string that outlives the
        // call. A null `self.inner` is passed through deliberately: the
        // wrapper is expected to reject it with an error.
        let result = unsafe { ffi::udpipe_parse(self.inner, c_text.as_ptr()) };
        if result.is_null() {
            return Err(UdpipeError {
                message: get_ffi_error(),
            });
        }
        // SAFETY: `result` is non-null and has not been freed yet.
        let word_count = unsafe { ffi::udpipe_result_word_count(result) };
        // A negative count yields no capacity and an empty loop below.
        let capacity = usize::try_from(word_count).unwrap_or(0);
        let mut words = Vec::with_capacity(capacity);
        for index in 0..word_count {
            // SAFETY: `result` is still live and `index` is within the
            // reported word count.
            let raw = unsafe { ffi::udpipe_result_get_word(result, index) };
            // Copy every string out before the result is freed.
            words.push(Word {
                form: ptr_to_string(raw.form),
                lemma: ptr_to_string(raw.lemma),
                upostag: ptr_to_string(raw.upostag),
                xpostag: ptr_to_string(raw.xpostag),
                feats: ptr_to_string(raw.feats),
                deprel: ptr_to_string(raw.deprel),
                misc: ptr_to_string(raw.misc),
                id: raw.id,
                head: raw.head,
                sentence_id: raw.sentence_id,
            });
        }
        // SAFETY: `result` came from `udpipe_parse` and is freed exactly once;
        // no pointers into it are used after this point.
        unsafe { ffi::udpipe_result_free(result) };
        Ok(words)
    }
}
/// Copies a C string into an owned `String`, replacing invalid UTF-8.
///
/// A null pointer yields an empty string instead of being dereferenced: the
/// FFI layer can hand back null string fields, and `CStr::from_ptr` on null
/// is undefined behaviour.
fn ptr_to_string(ptr: *const std::os::raw::c_char) -> String {
    if ptr.is_null() {
        return String::new();
    }
    // SAFETY: `ptr` is non-null; the caller must pass a pointer to a
    // NUL-terminated string that is valid for the duration of this call.
    unsafe { CStr::from_ptr(ptr) }
        .to_string_lossy()
        .into_owned()
}
impl Drop for Model {
    /// Releases the underlying model handle, if there is one.
    fn drop(&mut self) {
        // A null handle owns nothing, so there is nothing to free.
        if self.inner.is_null() {
            return;
        }
        // SAFETY: the handle is non-null and `drop` runs at most once, so it
        // is freed exactly once.
        unsafe { ffi::udpipe_model_free(self.inner) };
    }
}
/// Model identifiers accepted by `download_model`.
///
/// Each entry is passed to `model_filename` to build the name of the file to
/// download. The entries look like `<language>-<treebank>` (presumably the
/// treebanks of one Universal Dependencies release — confirm against the
/// model repository). The list is kept in ascending order; the unit tests
/// assert that it is sorted.
pub const AVAILABLE_MODELS: &[&str] = &[
"afrikaans-afribooms",
"ancient_greek-perseus",
"ancient_greek-proiel",
"arabic-padt",
"armenian-armtdp",
"basque-bdt",
"belarusian-hse",
"bulgarian-btb",
"buryat-bdt",
"catalan-ancora",
"chinese-gsd",
"chinese-gsdsimp",
"classical_chinese-kyoto",
"coptic-scriptorium",
"croatian-set",
"czech-cac",
"czech-cltt",
"czech-fictree",
"czech-pdt",
"danish-ddt",
"dutch-alpino",
"dutch-lassysmall",
"english-ewt",
"english-gum",
"english-lines",
"english-partut",
"estonian-edt",
"estonian-ewt",
"finnish-ftb",
"finnish-tdt",
"french-gsd",
"french-partut",
"french-sequoia",
"french-spoken",
"galician-ctg",
"galician-treegal",
"german-gsd",
"german-hdt",
"gothic-proiel",
"greek-gdt",
"hebrew-htb",
"hindi-hdtb",
"hungarian-szeged",
"indonesian-gsd",
"irish-idt",
"italian-isdt",
"italian-partut",
"italian-postwita",
"italian-twittiro",
"italian-vit",
"japanese-gsd",
"kazakh-ktb",
"korean-gsd",
"korean-kaist",
"kurmanji-mg",
"latin-ittb",
"latin-perseus",
"latin-proiel",
"latvian-lvtb",
"lithuanian-alksnis",
"lithuanian-hse",
"maltese-mudt",
"marathi-ufal",
"north_sami-giella",
"norwegian-bokmaal",
"norwegian-nynorsk",
"norwegian-nynorsklia",
"old_church_slavonic-proiel",
"old_french-srcmf",
"old_russian-torot",
"persian-seraji",
"polish-lfg",
"polish-pdb",
"polish-sz",
"portuguese-bosque",
"portuguese-br",
"portuguese-gsd",
"romanian-nonstandard",
"romanian-rrt",
"russian-gsd",
"russian-syntagrus",
"russian-taiga",
"sanskrit-ufal",
"scottish_gaelic-arcosg",
"serbian-set",
"slovak-snk",
"slovenian-ssj",
"slovenian-sst",
"spanish-ancora",
"spanish-gsd",
"swedish-lines",
"swedish-talbanken",
"tamil-ttb",
"telugu-mtg",
"turkish-imst",
"ukrainian-iu",
"upper_sorbian-ufal",
"urdu-udtb",
"uyghur-udt",
"vietnamese-vtb",
"wolof-wtb",
];
/// Downloads the model named `language` into `dest_dir` and returns the path
/// of the downloaded file as a (lossily converted) string.
///
/// `language` must be one of the identifiers in `AVAILABLE_MODELS`, e.g.
/// `english-ewt`.
///
/// # Errors
///
/// Returns an error if `language` is not a known identifier (the message
/// lists the first five as examples), or if the download itself fails.
pub fn download_model(language: &str, dest_dir: impl AsRef<Path>) -> Result<String, UdpipeError> {
    let is_known = AVAILABLE_MODELS.iter().any(|&known| known == language);
    if !is_known {
        // Only a short sample is shown; the full list is too long for a message.
        let examples = AVAILABLE_MODELS[..5].join(", ") + ", ...";
        let message = format!("Unknown language '{}'. Use one of: {}", language, examples);
        return Err(UdpipeError { message });
    }

    let filename = model_filename(language);
    let url = format!("{MODEL_BASE_URL}/{filename}");
    let dest_path = dest_dir.as_ref().join(&filename);
    download_model_from_url(&url, &dest_path)?;

    let dest = dest_path.to_string_lossy().into_owned();
    Ok(dest)
}
/// Downloads `url` and writes the response body to the file at `path`.
///
/// The request is made first; the destination file is created (or
/// truncated) only once the request has succeeded.
///
/// # Errors
///
/// Returns an error if the request fails (`Failed to download: ...`), if the
/// file cannot be created or written, or if the response body is empty. In
/// the empty case the file has already been created and is left on disk.
pub fn download_model_from_url(url: &str, path: impl AsRef<Path>) -> Result<(), UdpipeError> {
    let path = path.as_ref();
    let response = ureq::get(url).call().map_err(|e| UdpipeError {
        message: format!("Failed to download: {e}"),
    })?;
    let file = File::create(path)?;
    let mut writer = BufWriter::new(file);
    let bytes_written = std::io::copy(&mut response.into_body().into_reader(), &mut writer)?;
    // Flush explicitly: dropping a `BufWriter` also flushes, but it discards
    // any write error, which would report a truncated file as success.
    std::io::Write::flush(&mut writer)?;
    if bytes_written == 0 {
        return Err(UdpipeError {
            message: "Downloaded file is empty".to_owned(),
        });
    }
    Ok(())
}
/// Returns the file name under which the model `language` is published:
/// the identifier followed by the fixed `-ud-2.5-191206.udpipe` suffix.
///
/// The identifier is not validated here.
#[must_use]
pub fn model_filename(language: &str) -> String {
    [language, "-ud-2.5-191206.udpipe"].concat()
}
// Unit tests for this module.
//
// Three groups live here:
//   * pure-Rust tests of `Word`, `UdpipeError`, `model_filename` and
//     `AVAILABLE_MODELS`;
//   * tests that call into the C wrapper (`Model::*` and the raw `ffi::*`
//     functions), which need the native library to be linked;
//   * download tests, which use the `tempfile` and `mockito` crates and make
//     HTTP requests (to a local mock server or to addresses expected to fail).
#[cfg(test)]
mod tests {
use super::*;
// Builds a NOUN/root word with the given feature string and empty `misc`;
// individual tests overwrite the fields they care about.
fn make_word(feats: &str) -> Word {
Word {
form: "test".to_owned(),
lemma: "test".to_owned(),
upostag: "NOUN".to_owned(),
xpostag: String::new(),
feats: feats.to_owned(),
deprel: "root".to_owned(),
misc: String::new(),
id: 1,
head: 0,
sentence_id: 0,
}
}
#[test]
fn test_word_has_feature() {
let word = make_word("Mood=Imp|VerbForm=Fin");
assert!(word.has_feature("Mood", "Imp"));
assert!(word.has_feature("VerbForm", "Fin"));
assert!(!word.has_feature("Mood", "Ind"));
assert!(!word.has_feature("Tense", "Past"));
}
#[test]
fn test_word_has_feature_empty() {
let word = make_word("");
assert!(!word.has_feature("Mood", "Imp"));
}
#[test]
fn test_word_has_feature_single() {
let word = make_word("Mood=Imp");
assert!(word.has_feature("Mood", "Imp"));
assert!(!word.has_feature("VerbForm", "Fin"));
}
#[test]
fn test_word_get_feature() {
let word = make_word("Tense=Pres|VerbForm=Part");
assert_eq!(word.get_feature("Tense"), Some("Pres"));
assert_eq!(word.get_feature("VerbForm"), Some("Part"));
assert_eq!(word.get_feature("Mood"), None);
}
#[test]
fn test_word_get_feature_empty() {
let word = make_word("");
assert_eq!(word.get_feature("Mood"), None);
}
#[test]
fn test_word_get_feature_single() {
let word = make_word("Mood=Imp");
assert_eq!(word.get_feature("Mood"), Some("Imp"));
assert_eq!(word.get_feature("VerbForm"), None);
}
#[test]
fn test_word_is_verb() {
let mut word = make_word("");
word.upostag = "VERB".to_owned();
assert!(word.is_verb());
word.upostag = "AUX".to_owned();
assert!(word.is_verb());
word.upostag = "NOUN".to_owned();
assert!(!word.is_verb());
}
#[test]
fn test_word_is_noun() {
let mut word = make_word("");
word.upostag = "NOUN".to_owned();
assert!(word.is_noun());
word.upostag = "PROPN".to_owned();
assert!(word.is_noun());
word.upostag = "VERB".to_owned();
assert!(!word.is_noun());
}
#[test]
fn test_word_is_root() {
let mut word = make_word("");
word.deprel = "root".to_owned();
assert!(word.is_root());
word.deprel = "nsubj".to_owned();
assert!(!word.is_root());
}
#[test]
fn test_word_is_adjective() {
let mut word = make_word("");
word.upostag = "ADJ".to_owned();
assert!(word.is_adjective());
word.upostag = "NOUN".to_owned();
assert!(!word.is_adjective());
}
#[test]
fn test_word_is_punct() {
let mut word = make_word("");
word.upostag = "PUNCT".to_owned();
assert!(word.is_punct());
word.upostag = "NOUN".to_owned();
assert!(!word.is_punct());
}
// Two words with identical fields must hash and compare equal.
#[test]
fn test_word_hash() {
use std::collections::HashSet;
let word1 = make_word("Mood=Imp");
let word2 = make_word("Mood=Imp");
let mut set = HashSet::new();
set.insert(word1);
assert!(set.contains(&word2));
}
#[test]
fn test_model_filename() {
assert_eq!(
model_filename("english-ewt"),
"english-ewt-ud-2.5-191206.udpipe"
);
assert_eq!(
model_filename("dutch-alpino"),
"dutch-alpino-ud-2.5-191206.udpipe"
);
}
#[test]
fn test_available_models_contains_common_languages() {
assert!(AVAILABLE_MODELS.contains(&"english-ewt"));
assert!(AVAILABLE_MODELS.contains(&"german-gsd"));
assert!(AVAILABLE_MODELS.contains(&"french-gsd"));
assert!(AVAILABLE_MODELS.contains(&"spanish-ancora"));
}
// Guards the ordering of the table so new entries are inserted in place.
#[test]
fn test_available_models_sorted() {
let mut sorted = AVAILABLE_MODELS.to_vec();
sorted.sort_unstable();
assert_eq!(AVAILABLE_MODELS, sorted.as_slice());
}
// The name check happens before any network access, so this stays offline.
#[test]
fn test_download_model_invalid_language() {
let result = download_model("invalid-language-xyz", ".");
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.message.contains("Unknown language"));
}
#[test]
fn test_udpipe_error_display() {
let err = UdpipeError::new("test error");
assert_eq!(format!("{err}"), "UDPipe error: test error");
}
#[test]
fn test_udpipe_error_from_io() {
let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
let err: UdpipeError = io_err.into();
assert!(err.message.contains("not found"));
}
#[test]
fn test_has_space_after() {
let mut word = make_word("");
word.misc = String::new();
assert!(word.has_space_after());
word.misc = "SpaceAfter=No".to_owned();
assert!(!word.has_space_after());
word.misc = "SpaceAfter=No|Other=Value".to_owned();
assert!(!word.has_space_after());
}
// From here on the tests call into the native library.
#[test]
fn test_model_load_nonexistent_file() {
let result = Model::load("/nonexistent/path/to/model.udpipe");
assert!(result.is_err());
}
// The NUL byte is rejected on the Rust side, before the C loader is called.
#[test]
fn test_model_load_path_with_null_byte() {
let result = Model::load("path\0with\0nulls.udpipe");
let err = result.expect_err("expected error");
assert!(err.message.contains("null byte"));
}
#[test]
fn test_model_load_from_memory_empty() {
let result = Model::load_from_memory(&[]);
assert!(result.is_err());
}
#[test]
fn test_model_load_from_memory_invalid() {
let garbage = b"this is not a valid udpipe model";
let result = Model::load_from_memory(garbage);
assert!(result.is_err());
}
// Relies on the C wrapper rejecting a null model with a message containing
// "Invalid arguments"; dropping the null `Model` afterwards frees nothing.
#[test]
fn test_parse_with_null_model() {
let model = Model {
inner: std::ptr::null_mut(),
};
let result = model.parse("test");
let err = result.unwrap_err();
assert!(err.message.contains("Invalid arguments"));
}
#[test]
fn test_model_debug() {
let model = Model {
inner: std::ptr::null_mut(),
};
let debug_str = format!("{model:?}");
assert!(debug_str.contains("Model"));
assert!(debug_str.contains("inner"));
}
// Download tests: the request is expected to fail for the first two.
#[test]
fn test_download_model_from_url_invalid_url() {
let temp_dir = tempfile::tempdir().unwrap();
let path = temp_dir.path().join("model.udpipe");
let result = download_model_from_url("http://invalid.invalid/no-such-model", &path);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.message.contains("Failed to download"));
}
// NOTE(review): the request to localhost:1 presumably fails before the file
// is created, so this may not actually exercise the missing-directory path
// — confirm.
#[test]
fn test_download_model_from_url_nonexistent_dir() {
let temp_dir = tempfile::tempdir().unwrap();
let path = temp_dir.path().join("nonexistent/model.udpipe");
let url = "http://localhost:1/model.udpipe";
let result = download_model_from_url(url, &path);
assert!(result.is_err());
}
// A 200 response with an empty body must be reported as an error.
#[test]
fn test_download_model_from_url_empty_response() {
let temp_dir = tempfile::tempdir().unwrap();
let path = temp_dir.path().join("model.udpipe");
let mut server = mockito::Server::new();
let mock = server
.mock("GET", "/empty-model.udpipe")
.with_status(200)
.with_body("")
.create();
let full_url = format!("{}/empty-model.udpipe", server.url());
let result = download_model_from_url(&full_url, &path);
mock.assert();
drop(server);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.message.contains("empty"));
}
// The remaining tests pin how the raw C functions react to a null result.
#[test]
fn test_ffi_null_result_word_count() {
let count = unsafe { ffi::udpipe_result_word_count(std::ptr::null_mut()) };
assert_eq!(count, 0);
}
#[test]
fn test_ffi_null_result_get_word() {
let word = unsafe { ffi::udpipe_result_get_word(std::ptr::null_mut(), 0) };
assert!(word.form.is_null());
assert!(word.lemma.is_null());
assert!(word.upostag.is_null());
}
#[test]
fn test_ffi_invalid_index() {
let word = unsafe { ffi::udpipe_result_get_word(std::ptr::null_mut(), -1) };
assert!(word.form.is_null());
}
}