import onnxruntime as ort
import numpy as np
import soundfile as sf
import time
import os
import torchaudio
import torch
def load_tokens(path):
    """Load a token table mapping integer ids to token strings.

    Each non-empty line is expected to be "<token> <id>".  The
    SentencePiece word-boundary marker '▁' is rewritten as a plain
    space so decoded pieces concatenate into readable text.  Lines
    with fewer than two fields are skipped.
    """
    table = {}
    with open(path, 'r', encoding='utf-8') as fh:
        for raw_line in fh:
            fields = raw_line.strip().split()
            if len(fields) < 2:
                continue  # malformed line — ignore
            table[int(fields[1])] = fields[0].replace('▁', ' ')
    return table
def greedy_decode(logits, tokens):
indices = np.argmax(logits[0], axis=-1)
res = []
prev = -1
for idx in indices:
if idx != 0 and idx != prev:
res.append(tokens.get(idx, f"<{idx}>"))
prev = idx
return "".join(res)
def main():
    """Run SenseVoice int8 ONNX inference on a fixture WAV and report timings.

    Pipeline: read audio -> Mel spectrogram -> LFR stacking -> per-file
    mean/variance normalization -> single-threaded ONNX Runtime session
    (one warm-up run, then a timed run) -> greedy decode -> print RTF stats.
    """
    wav_path = 'fixtures/zh.wav'
    model_path = 'examples/sensevoice/sensevoice.int8.onnx'
    tokens_path = 'examples/sensevoice/sensevoice.int8.tokens.txt'
    if not os.path.exists(wav_path):
        print(f"Error: {wav_path} not found")
        return
    audio, sr = sf.read(wav_path)
    # sf.read returns a (samples, channels) array for multi-channel files;
    # the model expects mono, so average the channels down.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    # Fix: the previous code silently ignored a sample-rate mismatch (`pass`),
    # which would feed wrongly-scaled features to the model.  Resample to the
    # model's expected 16 kHz instead.
    if sr != 16000:
        audio = torchaudio.functional.resample(
            torch.from_numpy(audio).float(), sr, 16000).numpy()
        sr = 16000
    duration = len(audio) / sr
    print(f"Audio file: {wav_path}")
    print(f"Audio duration: {duration:.2f}s")

    print("Extracting features (Mel + LFR)...")
    start_feat = time.perf_counter()
    waveform = torch.from_numpy(audio).float().unsqueeze(0)
    # 25 ms window / 10 ms hop at 16 kHz, 80 Mel bins — standard ASR front end.
    mel_spec = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000,
        n_fft=400,
        win_length=400,
        hop_length=160,
        n_mels=80,
        center=False)(waveform)
    mel = torch.log(mel_spec + 1e-6)  # epsilon avoids log(0)
    mel = mel.squeeze(0).transpose(0, 1).numpy()  # -> (frames, n_mels)

    def compute_lfr(features, m=7, n=6):
        """Low Frame Rate stacking: every n-th frame becomes the center of an
        m-frame window concatenated along the feature axis; out-of-range
        frames are clamped to the first/last frame."""
        T, D = features.shape
        lfr_features = []
        for i in range(0, T, n):
            start = i - (m // 2)
            frames = []
            for j in range(start, start + m):
                if j < 0:
                    frames.append(features[0])
                elif j >= T:
                    frames.append(features[T - 1])
                else:
                    frames.append(features[j])
            lfr_features.append(np.concatenate(frames))
        return np.array(lfr_features)

    lfr_feats = compute_lfr(mel)
    # Per-utterance CMVN: zero mean / unit variance along the time axis.
    lfr_feats = (lfr_feats - np.mean(lfr_feats, axis=0)) / (np.std(lfr_feats, axis=0) + 1e-6)
    feat_time = (time.perf_counter() - start_feat) * 1000
    print(f"✓ Features extracted: {lfr_feats.shape}, took {feat_time:.2f}ms")

    print("Initializing ONNX Runtime...")
    sess_options = ort.SessionOptions()
    # Single-threaded so the reported timing reflects one-core performance.
    sess_options.intra_op_num_threads = 1
    sess_options.inter_op_num_threads = 1
    sess = ort.InferenceSession(model_path, sess_options, providers=['CPUExecutionProvider'])

    T_lfr = lfr_feats.shape[0]
    inputs = {
        'x': lfr_feats[np.newaxis, ...].astype(np.float32),
        'x_length': np.array([T_lfr], dtype=np.int32),
        # language=3 / text_norm=0: model-specific control ids — TODO confirm
        # against the exported model's metadata.
        'language': np.array([3], dtype=np.int32),
        'text_norm': np.array([0], dtype=np.int32),
    }

    print("Running model inference...")
    _ = sess.run(None, inputs)  # warm-up run, excluded from timing
    start_inf = time.perf_counter()
    outputs = sess.run(None, inputs)
    inf_time = (time.perf_counter() - start_inf) * 1000

    tokens = load_tokens(tokens_path)
    text = greedy_decode(outputs[0], tokens)
    print(f"\n=== SenseVoice ORT Results ===")
    print(f"Result: {text}")
    print(f"Inference: {inf_time:.2f} ms")
    print(f"Model RTF: {inf_time/1000 / duration:.4f}")
    total_time = feat_time + inf_time
    print(f"Total RTF: {(total_time/1000) / duration:.4f}")
# Script entry point: only run inference when executed directly, not on import.
if __name__ == "__main__":
    main()