use std::iter;
use bitfield::bitfield;
use bitvec::prelude::*;
use probe_rs_target::ScanChainElement;
use crate::probe::{
AutoImplementJtagAccess, BatchExecutionError, ChainParams, CommandQueue, CommandResult,
DebugProbeError, DeferredResultSet, JtagAccess, JtagCommand, JtagSequence, RawJtagIo,
};
/// Packs up to the first 32 booleans of `bits` into a `u32`, least-significant
/// bit first.
///
/// Element `i` of the iterator becomes bit `i` of the result; any elements past
/// the 32nd are ignored. Despite the name, the result is a full 32-bit word.
pub(crate) fn bits_to_byte(bits: impl IntoIterator<Item = bool>) -> u32 {
    bits.into_iter()
        .take(32)
        .enumerate()
        .filter(|&(_, set)| set)
        .fold(0u32, |word, (pos, _)| word | (1u32 << pos))
}
// A 32-bit JTAG IDCODE register value with accessors for its bit fields.
// In `bitfield!` syntax, the bare type lines (`u8;`, `u16;`, `bool;`) choose
// the getter/setter type for the fields that follow them.
bitfield! {
    #[derive(Copy, Clone, Eq, PartialEq)]
    pub struct IdCode(u32);
    impl Debug;
    u8;
    // Bits 31..=28: version.
    pub version, set_version: 31, 28;
    u16;
    // Bits 27..=12: part number.
    pub part_number, set_part_number: 27, 12;
    // Bits 11..=1: the whole manufacturer field (continuation bits + identity).
    pub manufacturer, set_manufacturer: 11, 1;
    u8;
    // Bits 11..=8: JEP106 continuation count, used as the bank in `manufacturer_name`.
    pub manufacturer_continuation, set_manufacturer_continuation: 11, 8;
    // Bits 7..=1: JEP106 identity code within that bank.
    pub manufacturer_identity, set_manufacturer_identity: 7, 1;
    bool;
    // Bit 0: must be 1 for a valid IDCODE (a 0 bit is treated as a bypass TAP
    // by `extract_idcodes`).
    pub lsbit, set_lsbit: 0;
}
impl std::fmt::Display for IdCode {
    /// Renders the raw value as `0x` plus eight upper-case hex digits, followed
    /// by the JEP106 manufacturer name in parentheses when it is known.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.manufacturer_name() {
            None => write!(f, "0x{:08X}", self.0),
            Some(mfn) => write!(f, "0x{:08X} ({})", self.0, mfn),
        }
    }
}
impl IdCode {
    /// Returns whether this value can be a genuine IDCODE.
    ///
    /// Bit 0 must be set, and the 11-bit manufacturer field must be neither 0
    /// nor 127.
    pub fn valid(&self) -> bool {
        if !self.lsbit() {
            return false;
        }
        !matches!(self.manufacturer(), 0 | 127)
    }

    /// Looks up the manufacturer's name in the JEP106 table.
    ///
    /// Uses the continuation bits as the bank and the 7-bit identity.
    /// Returns `None` when the code is not listed.
    pub fn manufacturer_name(&self) -> Option<&'static str> {
        let code = jep106::JEP106Code::new(
            self.manufacturer_continuation(),
            self.manufacturer_identity(),
        );
        code.get()
    }
}
/// Errors raised while decoding the bits captured during a JTAG scan-chain scan.
#[derive(Debug, thiserror::Error)]
pub enum ScanChainError {
    /// The DR scan contained an IDCODE that was truncated or failed `IdCode::valid`.
    #[error("Invalid IDCODE")]
    InvalidIdCode,
    /// The IR scan could not be split into one instruction register length per TAP.
    #[error("Invalid IR scan chain")]
    InvalidIR,
}
/// Converts ascending register start offsets into register lengths.
///
/// Every length except the last is the distance to the next start. The last
/// length is `total` minus the sum of the earlier ones, which equals
/// `total - starts.last()` when `starts` begins at 0 (as all callers ensure).
/// An empty `starts` yields `[total]`.
fn starts_to_lengths(starts: &[usize], total: usize) -> Vec<usize> {
    let mut lengths = Vec::with_capacity(starts.len().max(1));
    let mut consumed = 0;
    for pair in starts.windows(2) {
        let len = pair[1] - pair[0];
        consumed += len;
        lengths.push(len);
    }
    lengths.push(total - consumed);
    lengths
}
/// Decodes the bits captured while shifting ones through the DR chain.
///
/// The captured stream is consumed from its first bit:
/// - a `0` bit counts as one TAP exposing a single bypass bit;
/// - a `1` bit starts a 32-bit IDCODE, loaded little-endian;
/// - a 32-bit word of all ones ends the scan.
///
/// Bypass TAPs are recorded as `None` only once a later IDCODE is found. Zero
/// bits after the last IDCODE are therefore not reported as TAPs.
///
/// # Errors
/// Returns `ScanChainError::InvalidIdCode` if an IDCODE is cut off by the end
/// of `dr` or fails [`IdCode::valid`].
pub(crate) fn extract_idcodes<T: BitStore>(
    mut dr: &BitSlice<T>,
) -> Result<Vec<Option<IdCode>>, ScanChainError> {
    let mut idcodes = Vec::new();
    // Bypass bits seen since the last IDCODE that have not yet been committed.
    let mut accumulated_bypass_taps = 0;

    while !dr.is_empty() {
        if !dr[0] {
            accumulated_bypass_taps += 1;
            dr = &dr[1..];
            continue;
        }

        if dr.len() < 32 {
            tracing::error!("Truncated IDCODE: {dr:02X?}");
            return Err(ScanChainError::InvalidIdCode);
        }

        let raw = dr[0..32].load_le::<u32>();
        if raw == u32::MAX {
            // All ones: only the filler shifted in behind the chain remains.
            break;
        }

        let idcode = IdCode(raw);
        if !idcode.valid() {
            tracing::error!("Invalid IDCODE: {:08X}", idcode.0);
            return Err(ScanChainError::InvalidIdCode);
        }

        if accumulated_bypass_taps > 0 {
            tracing::info!("Appending {accumulated_bypass_taps} bypass taps");
            idcodes.extend(std::iter::repeat_n(None, accumulated_bypass_taps));
            accumulated_bypass_taps = 0;
        }

        tracing::info!("Found IDCODE: {idcode}");
        idcodes.push(Some(idcode));
        dr = &dr[32..];
    }

    Ok(idcodes)
}
/// Returns the longest prefix of `a` on which `a` and `b` agree bit for bit.
///
/// The result borrows from `a` and is at most `min(a.len(), b.len())` bits long.
pub(crate) fn common_sequence<'a, S: BitStore>(
    a: &'a BitSlice<S>,
    b: &BitSlice<S>,
) -> &'a BitSlice<S> {
    let shared_len = a
        .iter()
        .zip(b.iter())
        .position(|(x, y)| *x != *y)
        .unwrap_or_else(|| a.len().min(b.len()));
    &a[..shared_len]
}
/// Splits an IR capture pattern into one instruction register length per TAP.
///
/// Every IR is assumed to begin with a `1` immediately followed by a `0`, so
/// each `1, 0` pair in `ir` is a candidate IR start. Then:
/// - if `expected` lengths are given, they are accepted only when there is one
///   per TAP, they sum to `ir.len()`, and each implied start is a candidate;
/// - with a single TAP, the whole scan is its IR;
/// - with exactly `n_taps` candidates, the lengths follow directly;
/// - with more candidates than TAPs, every 2-bit piece is merged with the piece
///   after it (a `101xx…` capture value), which must leave exactly `n_taps`
///   lengths.
///
/// # Errors
/// Returns `ScanChainError::InvalidIR` when `n_taps` is 0, fewer candidates
/// than TAPs are found, `ir` does not start with a candidate, `expected` is
/// inconsistent with the scan, or the lengths remain ambiguous.
pub(crate) fn extract_ir_lengths<T: BitStore>(
    ir: &BitSlice<T>,
    n_taps: usize,
    expected: Option<&[usize]>,
) -> Result<Vec<usize>, ScanChainError> {
    // Checks caller-supplied IR lengths against the candidate starts and,
    // if they fit, returns them as per-TAP lengths.
    fn verify_expected(
        expected: &[usize],
        starts: &[usize],
        n_taps: usize,
        scan_len: usize,
    ) -> Result<Vec<usize>, ScanChainError> {
        if expected.len() != n_taps {
            tracing::error!(
                "Number of provided IR lengths ({}) does not match \
                 number of detected TAPs ({n_taps})",
                expected.len()
            );
            return Err(ScanChainError::InvalidIR);
        }
        let expected_total: usize = expected.iter().sum();
        if expected_total != scan_len {
            tracing::error!(
                "Sum of provided IR lengths ({}) does not match \
                 length of IR scan ({} bits)",
                expected_total,
                scan_len
            );
            return Err(ScanChainError::InvalidIR);
        }

        // Offset at which each provided IR would begin.
        let mut offset = 0;
        let exp_starts: Vec<usize> = expected
            .iter()
            .map(|&len| {
                let start = offset;
                offset += len;
                start
            })
            .collect();
        tracing::trace!("Provided IR start positions: {exp_starts:?}");

        if exp_starts.iter().any(|start| !starts.contains(start)) {
            tracing::error!(
                "Provided IR lengths imply an IR start position \
                 which is not supported by the IR scan"
            );
            return Err(ScanChainError::InvalidIR);
        }
        tracing::debug!("Verified provided IR lengths against IR scan");
        Ok(starts_to_lengths(&exp_starts, scan_len))
    }

    let starts: Vec<usize> = ir
        .windows(2)
        .enumerate()
        .filter_map(|(pos, pair)| (pair[0] && !pair[1]).then_some(pos))
        .collect();
    tracing::trace!("Possible IR start positions: {starts:?}");

    if n_taps == 0 {
        tracing::error!("Cannot scan IR without at least one TAP");
        return Err(ScanChainError::InvalidIR);
    }
    if starts.len() < n_taps {
        tracing::error!("Fewer IRs detected than TAPs");
        return Err(ScanChainError::InvalidIR);
    }
    if starts[0] != 0 {
        tracing::error!("IR chain does not begin with a valid start pattern");
        return Err(ScanChainError::InvalidIR);
    }

    if let Some(expected) = expected {
        return verify_expected(expected, &starts, n_taps, ir.len());
    }

    if n_taps == 1 {
        tracing::info!("Only one TAP detected, IR length {}", ir.len());
        return Ok(vec![ir.len()]);
    }

    let irlens = starts_to_lengths(&starts, ir.len());
    if starts.len() == n_taps {
        tracing::info!("IR lengths are unambiguous: {irlens:?}");
        return Ok(irlens);
    }

    // More candidate starts than TAPs: a 2-bit piece is assumed to be the
    // `10` head of a capture value such as `101xx`, so it is joined to the
    // piece that follows it.
    let mut pieces = irlens.into_iter();
    let mut merged = Vec::new();
    while let Some(len) = pieces.next() {
        let tail = if len == 2 { pieces.next().unwrap_or(0) } else { 0 };
        merged.push(len + tail);
    }
    if merged.len() == n_taps {
        tracing::info!("IR lengths after merging 101xx prefixes: {merged:?}");
        return Ok(merged);
    }

    tracing::error!("IR lengths are ambiguous and must be explicitly configured.");
    Err(ScanChainError::InvalidIR)
}
/// Position within the DR or IR column of the TAP controller state machine.
///
/// `Select` stands for Select-DR-Scan / Select-IR-Scan. The other variants map
/// onto Capture-, Shift-, Exit1-, Pause-, Exit2- and Update-xR. Transitions out
/// of `Select` (with TMS high) and out of `Update` leave the column, so they
/// are handled by `JtagState`.
#[derive(Clone, Copy, PartialEq, Debug)]
pub(crate) enum RegisterState {
    Select,
    Capture,
    Shift,
    Exit1,
    Pause,
    Exit2,
    Update,
}
impl RegisterState {
    /// Picks the TMS level that moves one edge closer to `target` within the
    /// same DR/IR column. Pass `Update` as `target` to leave the column.
    ///
    /// `Select` always yields `false` (towards Capture). Calling this on
    /// `Update` panics, because `JtagState` must handle that state itself.
    fn step_toward(self, target: Self) -> bool {
        match (self, target) {
            (Self::Select, _)
            | (Self::Capture, Self::Shift)
            | (Self::Exit1, Self::Pause | Self::Exit2)
            | (Self::Exit2, Self::Shift | Self::Exit1 | Self::Pause) => false,
            (Self::Update, _) => {
                unreachable!("This is a bug, this case should have been handled by JtagState.")
            }
            _ => true,
        }
    }

    /// Returns the state reached after clocking one TCK edge with the given TMS.
    ///
    /// Panics for `Select` with TMS high and for `Update`. Those transitions
    /// leave the column and are handled by `JtagState::update`.
    fn update(self, tms: bool) -> Self {
        match (self, tms) {
            (Self::Select | Self::Update, true) | (Self::Update, false) => {
                unreachable!("This is a bug, this case should have been handled by JtagState.")
            }
            (Self::Select, false) => Self::Capture,
            (Self::Capture | Self::Shift | Self::Exit2, false) => Self::Shift,
            (Self::Capture | Self::Shift, true) => Self::Exit1,
            (Self::Exit1 | Self::Pause, false) => Self::Pause,
            (Self::Exit1 | Self::Exit2, true) => Self::Update,
            (Self::Pause, true) => Self::Exit2,
        }
    }
}
/// Software model of the TAP controller's current state.
#[derive(Clone, Copy, PartialEq, Debug)]
pub(crate) enum JtagState {
    /// Test-Logic-Reset; TMS high keeps the controller here, TMS low goes to `Idle`.
    Reset,
    /// Run-Test/Idle; TMS low keeps the controller here.
    Idle,
    /// A state in the data-register column.
    Dr(RegisterState),
    /// A state in the instruction-register column.
    Ir(RegisterState),
}
impl JtagState {
    /// Returns the TMS level that moves the controller one edge closer to
    /// `target`, or `None` if it is already there.
    ///
    /// Inside a DR/IR column the choice is delegated to
    /// `RegisterState::step_toward`. A column is left through its `Update` state.
    pub fn step_toward(self, target: Self) -> Option<bool> {
        if self == target {
            return None;
        }
        let tms = match (self, target) {
            (Self::Reset, _) => false,
            (Self::Idle, _) => true,
            (Self::Dr(RegisterState::Select), Self::Dr(_)) => false,
            (Self::Dr(RegisterState::Select), _) => true,
            (Self::Ir(RegisterState::Select), Self::Ir(_)) => false,
            (Self::Ir(RegisterState::Select), _) => true,
            (
                Self::Dr(RegisterState::Update) | Self::Ir(RegisterState::Update),
                Self::Dr(_) | Self::Ir(_),
            ) => true,
            (Self::Dr(RegisterState::Update) | Self::Ir(RegisterState::Update), _) => false,
            (Self::Dr(current), Self::Dr(goal)) | (Self::Ir(current), Self::Ir(goal)) => {
                current.step_toward(goal)
            }
            (Self::Dr(current) | Self::Ir(current), _) => {
                current.step_toward(RegisterState::Update)
            }
        };
        Some(tms)
    }

    /// Advances the model by one TCK edge with the given TMS level.
    pub fn update(&mut self, tms: bool) {
        let next = match (*self, tms) {
            (Self::Reset, true) => Self::Reset,
            (Self::Reset, false) | (Self::Idle, false) => Self::Idle,
            (Self::Idle, true) => Self::Dr(RegisterState::Select),
            (Self::Dr(RegisterState::Select), true) => Self::Ir(RegisterState::Select),
            (Self::Ir(RegisterState::Select), true) => Self::Reset,
            (Self::Dr(RegisterState::Update) | Self::Ir(RegisterState::Update), true) => {
                Self::Dr(RegisterState::Select)
            }
            (Self::Dr(RegisterState::Update) | Self::Ir(RegisterState::Update), false) => {
                Self::Idle
            }
            (Self::Dr(inner), _) => Self::Dr(inner.update(tms)),
            (Self::Ir(inner), _) => Self::Ir(inner.update(tms)),
        };
        *self = next;
    }
}
/// Clocks TMS one bit at a time until the tracked TAP state equals `target`.
///
/// Each bit is chosen by `JtagState::step_toward` and shifted with TDI low and
/// no capture.
/// NOTE(review): this terminates only if `shift_bit` advances
/// `protocol.state().state`.
fn jtag_move_to_state(
    protocol: &mut impl RawJtagIo,
    target: JtagState,
) -> Result<(), DebugProbeError> {
    tracing::trace!(
        "Changing state: {:?} -> {:?}",
        protocol.state_mut().state,
        target
    );
    loop {
        let current = protocol.state().state;
        match current.step_toward(target) {
            Some(tms) => {
                protocol.shift_bit(tms, false, false)?;
            }
            None => break,
        }
    }
    tracing::trace!("In state: {:?}", protocol.state_mut().state);
    Ok(())
}
/// Shifts `len` bits of `data` (LSB of `data[0]` first) into the instruction
/// register of the currently selected TAP. The TAP is left in Update-IR.
///
/// The other TAPs on the chain are padded with `irpre` leading and `irpost`
/// trailing one bits, taken from `ChainParams`. These presumably select their
/// BYPASS instruction.
/// When `capture_data` is set, exactly the `len` bits belonging to the selected
/// TAP are flagged for capture; padding bits are never captured.
///
/// # Errors
/// Returns `DebugProbeError::Other` if `len` is zero or exceeds the bits
/// available in `data`. Probe errors are forwarded.
fn shift_ir(
    protocol: &mut impl RawJtagIo,
    data: &[u8],
    len: usize,
    capture_data: bool,
) -> Result<(), DebugProbeError> {
    tracing::debug!("Write IR: {:?}, len={}", data, len);
    // `data.len()` counts bytes; convert so both numbers in the error message are bits.
    let available_bits = data.len() * 8;
    if available_bits < len || len == 0 {
        return Err(DebugProbeError::Other(format!(
            "Invalid data length. IR bits: {}, expected: {}",
            available_bits, len
        )));
    }
    let pre_bits = protocol.state().chain_params.irpre;
    let post_bits = protocol.state().chain_params.irpost;
    // TMS stays low for every bit except the last one, which moves to Exit1-IR.
    let tms_data = std::iter::repeat_n(false, len - 1);
    jtag_move_to_state(protocol, JtagState::Ir(RegisterState::Shift))?;
    let tms = std::iter::repeat_n(false, pre_bits)
        .chain(tms_data)
        .chain(std::iter::repeat_n(false, post_bits))
        .chain(iter::once(true));
    let tdi = std::iter::repeat_n(true, pre_bits)
        .chain(data.as_bits::<Lsb0>()[..len].iter().map(|b| *b))
        .chain(std::iter::repeat_n(true, post_bits));
    let capture = std::iter::repeat_n(false, pre_bits)
        .chain(std::iter::repeat_n(capture_data, len))
        .chain(iter::repeat(false));
    tracing::trace!("tms: {:?}", tms.clone());
    tracing::trace!("tdi: {:?}", tdi.clone());
    protocol.shift_bits(tms, tdi, capture)?;
    jtag_move_to_state(protocol, JtagState::Ir(RegisterState::Update))?;
    Ok(())
}
/// Shifts `register_bits` bits of `data` (LSB of `data[0]` first) through the
/// data register of the currently selected TAP.
///
/// The other TAPs on the chain are padded with `drpre` leading and `drpost`
/// trailing zero bits, taken from `ChainParams`. The TAP is left in Update-DR.
/// If `jtag_idle_cycles` is non-zero, it instead ends in Run-Test/Idle after
/// that many extra clocks with TMS and TDI low.
///
/// Returns the number of bits flagged for capture: `register_bits` when
/// `capture_data` is set, otherwise 0. `write_register_batch` uses this to
/// slice the captured bit stream.
///
/// # Errors
/// Returns `DebugProbeError::Other` if `register_bits` is zero or exceeds the
/// bits available in `data`. Probe errors are forwarded.
fn shift_dr(
    protocol: &mut impl RawJtagIo,
    data: &[u8],
    register_bits: usize,
    capture_data: bool,
) -> Result<usize, DebugProbeError> {
    tracing::debug!("Write DR: {:?}, len={}", data, register_bits);
    // `data.len()` counts bytes; convert so both numbers in the error message are bits.
    let available_bits = data.len() * 8;
    if available_bits < register_bits || register_bits == 0 {
        return Err(DebugProbeError::Other(format!(
            "Invalid data length. DR bits: {}, expected: {}",
            available_bits, register_bits
        )));
    }
    // TMS stays low for every bit except the last one, which moves to Exit1-DR.
    let tms_shift_out_value = std::iter::repeat_n(false, register_bits - 1);
    jtag_move_to_state(protocol, JtagState::Dr(RegisterState::Shift))?;
    let pre_bits = protocol.state().chain_params.drpre;
    let post_bits = protocol.state().chain_params.drpost;
    let tms = std::iter::repeat_n(false, pre_bits)
        .chain(tms_shift_out_value)
        .chain(std::iter::repeat_n(false, post_bits))
        .chain(iter::once(true));
    let tdi = std::iter::repeat_n(false, pre_bits)
        .chain(data.as_bits::<Lsb0>()[..register_bits].iter().map(|b| *b))
        .chain(std::iter::repeat_n(false, post_bits));
    let capture = std::iter::repeat_n(false, pre_bits)
        .chain(std::iter::repeat_n(capture_data, register_bits))
        .chain(iter::repeat(false));
    protocol.shift_bits(tms, tdi, capture)?;
    jtag_move_to_state(protocol, JtagState::Dr(RegisterState::Update))?;
    let idle_cycles = protocol.state().jtag_idle_cycles;
    if idle_cycles > 0 {
        jtag_move_to_state(protocol, JtagState::Idle)?;
        let tms = std::iter::repeat_n(false, idle_cycles);
        let tdi = std::iter::repeat_n(false, idle_cycles);
        protocol.shift_bits(tms, tdi, iter::repeat(false))?;
    }
    if capture_data {
        Ok(register_bits)
    } else {
        Ok(0)
    }
}
/// Queues an IR load of `address` followed by a DR shift of `len` bits of `data`.
///
/// Returns the number of DR bits that will appear in the captured stream:
/// `len` when `capture` is set, otherwise 0. The IR shift itself is never
/// captured.
///
/// # Errors
/// Returns `DebugProbeError::Other` if `address` exceeds the state's
/// `max_ir_address()`. Errors from `shift_ir` / `shift_dr` are forwarded.
fn prepare_write_register(
    protocol: &mut impl RawJtagIo,
    address: u32,
    data: &[u8],
    len: u32,
    capture: bool,
) -> Result<usize, DebugProbeError> {
    let max_address = protocol.state().max_ir_address();
    if address > max_address {
        return Err(DebugProbeError::Other(format!(
            "Invalid instruction register access: {address}"
        )));
    }

    let ir_len = protocol.state().chain_params.irlen;
    let instruction = address.to_le_bytes();
    shift_ir(protocol, &instruction, ir_len, false)?;

    shift_dr(protocol, data, len as usize, capture)
}
/// Blanket [`JtagAccess`] implementation for every probe implementing
/// [`AutoImplementJtagAccess`].
///
/// Operations are built from this module's `shift_ir` / `shift_dr` /
/// `prepare_write_register` helpers plus the probe's `shift_bits`,
/// `read_captured_bits` and `state`/`state_mut` methods. Bits are queued
/// first, and the captured TDO bits are collected by `read_captured_bits`.
impl<Probe: AutoImplementJtagAccess> JtagAccess for Probe {
    /// Clocks out `sequence.data` with a constant TMS level and a constant
    /// capture flag, then returns the captured bits.
    fn shift_raw_sequence(&mut self, sequence: JtagSequence) -> Result<BitVec, DebugProbeError> {
        // TMS and capture repeat forever. The length of `sequence.data` presumably
        // bounds the number of clocks — TODO confirm against `shift_bits`.
        self.shift_bits(
            std::iter::repeat(sequence.tms),
            sequence.data.into_iter(),
            std::iter::repeat(sequence.tdo_capture),
        )?;
        self.read_captured_bits()
    }
    /// Stores the user-supplied chain description. Its IR lengths are checked
    /// by `scan_chain` the next time the chain is actually scanned.
    fn set_scan_chain(&mut self, scan_chain: &[ScanChainElement]) -> Result<(), DebugProbeError> {
        self.state_mut().expected_scan_chain = Some(scan_chain.to_vec());
        Ok(())
    }
    /// Selects TAP number `target`, scanning the chain first if it has not been
    /// discovered yet. Stores the derived `ChainParams` (IR/DR pre/post padding
    /// and IR length) used by later shifts.
    ///
    /// Returns `DebugProbeError::TargetNotFound` if `ChainParams::from_jtag_chain`
    /// yields `None`.
    fn select_target(&mut self, target: usize) -> Result<(), DebugProbeError> {
        if self.state().scan_chain.is_empty() {
            self.scan_chain()?;
        }
        let state = self.state_mut();
        let Some(params) = ChainParams::from_jtag_chain(&state.scan_chain, target) else {
            return Err(DebugProbeError::TargetNotFound);
        };
        tracing::debug!("Selecting JTAG TAP: {target}");
        tracing::debug!("Setting chain params: {params:?}");
        state.chain_params = params;
        Ok(())
    }
    /// Discovers the TAPs on the chain and caches the result. A previously
    /// discovered chain is returned unchanged.
    ///
    /// 1. DR scan: after a TAP reset, 256 one bits are shifted through DR and
    ///    the capture is decoded by `extract_idcodes`.
    /// 2. IR scans: `8 * n_taps` one bits are shifted through IR. After a
    ///    second reset, `n_taps` zero bytes followed by those one bytes are
    ///    shifted. The captured IR pattern is the prefix where both captures
    ///    agree; the bits after it differ because different data was shifted in.
    /// 3. `extract_ir_lengths` splits that pattern, checking it against the
    ///    IR lengths from `expected_scan_chain` if one was set.
    fn scan_chain(&mut self) -> Result<&[ScanChainElement], DebugProbeError> {
        if !self.state().scan_chain.is_empty() {
            return Ok(self.state().scan_chain.as_slice());
        }
        // Up to 8 TAPs * 32-bit IDCODE = 32 bytes of ones shifted through DR.
        const MAX_CHAIN: usize = 8;
        self.reset_jtag_state_machine()?;
        // Scan the raw chain: drop any padding from a previously selected TAP
        // (presumably `ChainParams::default()` has no pre/post bits — confirm).
        self.state_mut().chain_params = ChainParams::default();
        let input = [0xFF; 4 * MAX_CHAIN];
        shift_dr(self, &input, input.len() * 8, true)?;
        let response = self.read_captured_bits()?;
        tracing::debug!("DR: {:?}", response);
        let idcodes = extract_idcodes(&response)?;
        tracing::info!(
            "JTAG DR scan complete, found {} TAPs. {:?}",
            idcodes.len(),
            idcodes
        );
        tracing::debug!("Scanning JTAG chain for IR lengths");
        // NOTE(review): this scan is only `8 * n_taps` bits long, so the common
        // prefix below can be at most that long. Longer total IR chains cannot
        // be measured correctly. With no TAPs found, `shift_ir` rejects the
        // zero-length scan.
        let input = vec![0xff; idcodes.len()];
        shift_ir(self, &input, input.len() * 8, true)?;
        let response = self.read_captured_bits()?;
        tracing::debug!("IR scan: {}", response);
        self.reset_jtag_state_machine()?;
        // Same scan, preceded by `n_taps` zero bytes so the shifted-in filler differs.
        let input = std::iter::repeat_n(0, idcodes.len())
            .chain(input.iter().copied())
            .collect::<Vec<_>>();
        shift_ir(self, &input, input.len() * 8, true)?;
        let response_zeros = self.read_captured_bits()?;
        tracing::debug!("IR scan: {}", response_zeros);
        let response = response.as_bitslice();
        let response = common_sequence(response, response_zeros.as_bitslice());
        tracing::debug!("IR scan: {}", response);
        // Elements without an `ir_len` are skipped. A partially specified chain
        // therefore yields fewer lengths than TAPs and is rejected by
        // `extract_ir_lengths`.
        let ir_lens = extract_ir_lengths(
            response,
            idcodes.len(),
            self.state()
                .expected_scan_chain
                .as_ref()
                .map(|chain| {
                    chain
                        .iter()
                        .filter_map(|s| s.ir_len)
                        .map(|s| s as usize)
                        .collect::<Vec<usize>>()
                })
                .as_deref(),
        )?;
        tracing::info!("Found {} TAPs on reset scan", idcodes.len());
        tracing::debug!("Detected IR lens: {:?}", ir_lens);
        // Bypass TAPs (`None` IDCODE) get no name. `irlen as u8` truncates
        // lengths above 255.
        let chain = idcodes
            .into_iter()
            .zip(ir_lens)
            .map(|(idcode, irlen)| ScanChainElement {
                ir_len: Some(irlen as u8),
                name: idcode.map(|i| i.to_string()),
            })
            .collect::<Vec<_>>();
        self.state_mut().scan_chain = chain;
        Ok(self.state().scan_chain.as_slice())
    }
    /// Resets the TAP state machine via the probe.
    fn tap_reset(&mut self) -> Result<(), DebugProbeError> {
        self.reset_jtag_state_machine()
    }
    /// Sets how many extra Run-Test/Idle clocks `shift_dr` inserts after each DR shift.
    fn set_idle_cycles(&mut self, idle_cycles: u8) -> Result<(), DebugProbeError> {
        self.state_mut().jtag_idle_cycles = idle_cycles as usize;
        Ok(())
    }
    /// Returns the configured number of idle cycles.
    fn idle_cycles(&self) -> u8 {
        self.state().jtag_idle_cycles as u8
    }
    /// Reads register `address` by writing `len` zero bits to it and returning
    /// the captured bits.
    fn read_register(&mut self, address: u32, len: u32) -> Result<BitVec, DebugProbeError> {
        let data = vec![0u8; len.div_ceil(8) as usize];
        self.write_register(address, &data, len)
    }
    /// Loads instruction `address` into IR, shifts `len` bits of `data` through
    /// DR with capture enabled, and returns the captured bits.
    fn write_register(
        &mut self,
        address: u32,
        data: &[u8],
        len: u32,
    ) -> Result<BitVec, DebugProbeError> {
        prepare_write_register(self, address, data, len, true)?;
        let response = self.read_captured_bits()?;
        tracing::trace!("recieve_write_dr result: {:?}", response);
        Ok(response)
    }
    /// Shifts `len` bits of `data` through DR without touching IR, and returns
    /// the captured bits.
    fn write_dr(&mut self, data: &[u8], len: u32) -> Result<BitVec, DebugProbeError> {
        shift_dr(self, data, len as usize, true)?;
        let response = self.read_captured_bits()?;
        tracing::trace!("write_dr result: {:?}", response);
        Ok(response)
    }
    /// Executes a queue of commands with a single readback.
    ///
    /// All commands are queued first. The captured bits are then read once and
    /// sliced per command, in queue order, using the bit counts returned when
    /// each command was queued. Commands not flagged for capture consume no
    /// bits and yield `CommandResult::None`.
    ///
    /// On failure, the returned `BatchExecutionError` carries the results
    /// collected so far. That set is empty if the failure happened while
    /// queueing or reading back.
    #[tracing::instrument(skip(self, writes))]
    fn write_register_batch(
        &mut self,
        writes: &CommandQueue<JtagCommand>,
    ) -> Result<DeferredResultSet<CommandResult>, BatchExecutionError> {
        // (index, command, number of captured bits) for each queued command.
        let mut bits = Vec::with_capacity(writes.len());
        let t1 = std::time::Instant::now();
        tracing::debug!("Preparing {} writes...", writes.len());
        for (idx, command) in writes.iter() {
            let result = match command {
                JtagCommand::WriteRegister(write) => prepare_write_register(
                    self,
                    write.address,
                    &write.data,
                    write.len,
                    idx.should_capture(),
                ),
                JtagCommand::ShiftDr(write) => {
                    shift_dr(self, &write.data, write.len as usize, idx.should_capture())
                }
            };
            let op =
                result.map_err(|e| BatchExecutionError::new(e.into(), DeferredResultSet::new()))?;
            bits.push((idx, command, op));
        }
        tracing::debug!("Sending to chip...");
        let bitstream = self
            .read_captured_bits()
            .map_err(|e| BatchExecutionError::new(e.into(), DeferredResultSet::new()))?;
        tracing::debug!("Got responses! Took {:?}! Processing...", t1.elapsed());
        let mut responses = DeferredResultSet::with_capacity(bits.len());
        let mut bitstream = bitstream.as_bitslice();
        for (idx, command, bits) in bits.into_iter() {
            if idx.should_capture() {
                // Panics if the probe returned fewer bits than were queued for capture.
                let response = &bitstream[..bits];
                let result = match command {
                    JtagCommand::WriteRegister(command) => (command.transform)(command, response),
                    JtagCommand::ShiftDr(command) => (command.transform)(command, response),
                };
                match result {
                    Ok(response) => responses.push(idx, response),
                    Err(e) => return Err(BatchExecutionError::new(e, responses)),
                }
            } else {
                responses.push(idx, CommandResult::None);
            }
            // `bits` is 0 for non-capturing commands, so this is a no-op for them.
            bitstream = &bitstream[bits..];
        }
        Ok(responses)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const ARM_TAP: IdCode = IdCode(0x4BA00477);
    const STM_BS_TAP: IdCode = IdCode(0x06433041);

    /// Every state of the TAP controller model, for exhaustive path checks.
    const ALL_STATES: [JtagState; 16] = [
        JtagState::Reset,
        JtagState::Idle,
        JtagState::Dr(RegisterState::Select),
        JtagState::Dr(RegisterState::Capture),
        JtagState::Dr(RegisterState::Shift),
        JtagState::Dr(RegisterState::Exit1),
        JtagState::Dr(RegisterState::Pause),
        JtagState::Dr(RegisterState::Exit2),
        JtagState::Dr(RegisterState::Update),
        JtagState::Ir(RegisterState::Select),
        JtagState::Ir(RegisterState::Capture),
        JtagState::Ir(RegisterState::Shift),
        JtagState::Ir(RegisterState::Exit1),
        JtagState::Ir(RegisterState::Pause),
        JtagState::Ir(RegisterState::Exit2),
        JtagState::Ir(RegisterState::Update),
    ];

    #[test]
    fn id_code_display() {
        let debug_fmt = format!("{ARM_TAP}");
        assert_eq!(debug_fmt, "0x4BA00477 (ARM Ltd)");
        let debug_fmt = format!("{STM_BS_TAP}");
        assert_eq!(debug_fmt, "0x06433041 (STMicroelectronics)");
    }

    #[test]
    fn starts_to_lengths_splits_at_starts() {
        assert_eq!(starts_to_lengths(&[0, 4], 9), vec![4, 5]);
        assert_eq!(starts_to_lengths(&[0], 4), vec![4]);
    }

    #[test]
    fn extract_ir_lengths_with_one_tap() {
        let ir = bits![1, 0, 0, 0];
        let n_taps = 1;
        let expected = None;
        let ir_lengths = extract_ir_lengths(ir, n_taps, expected).unwrap();
        assert_eq!(ir_lengths, vec![4]);
    }

    #[test]
    fn extract_ir_lengths_with_two_taps() {
        let ir = bits![1, 0, 0, 0, 1, 0, 0, 0, 0];
        let n_taps = 2;
        let expected = None;
        let ir_lengths = extract_ir_lengths(ir, n_taps, expected).unwrap();
        assert_eq!(ir_lengths, vec![4, 5]);
    }

    #[test]
    fn extract_ir_lengths_with_two_taps_101() {
        let ir = bits![1, 0, 1, 0, 1, 0, 0, 0, 0];
        let n_taps = 2;
        let expected = None;
        let ir_lengths = extract_ir_lengths(ir, n_taps, expected).unwrap();
        assert_eq!(ir_lengths, vec![4, 5]);
    }

    #[test]
    fn extract_ir_lengths_with_expected_lengths() {
        let ir = bits![1, 0, 0, 0, 1, 0, 0, 0, 0];
        let ir_lengths = extract_ir_lengths(ir, 2, Some(&[4, 5][..])).unwrap();
        assert_eq!(ir_lengths, vec![4, 5]);
    }

    #[test]
    fn extract_ir_lengths_rejects_inconsistent_expected_lengths() {
        let ir = bits![1, 0, 0, 0, 1, 0, 0, 0, 0];
        // Wrong number of lengths.
        assert!(extract_ir_lengths(ir, 2, Some(&[9][..])).is_err());
        // Lengths do not add up to the scan length.
        assert!(extract_ir_lengths(ir, 2, Some(&[4, 4][..])).is_err());
        // Implied start (3) is not a `1, 0` boundary in the scan.
        assert!(extract_ir_lengths(ir, 2, Some(&[3, 6][..])).is_err());
    }

    #[test]
    fn extract_ir_lengths_rejects_invalid_scans() {
        let ir = bits![1, 0, 0, 0];
        // No TAPs.
        assert!(extract_ir_lengths(ir, 0, None).is_err());
        // More TAPs than candidate IR starts.
        assert!(extract_ir_lengths(ir, 2, None).is_err());
        // Scan does not begin with a `1, 0` start pattern.
        assert!(extract_ir_lengths(bits![0, 1, 0, 0], 1, None).is_err());
        // Three equally plausible IRs for two TAPs.
        assert!(extract_ir_lengths(bits![1, 0, 0, 1, 0, 0, 1, 0, 0], 2, None).is_err());
    }

    #[test]
    fn extract_id_codes_one_tap() {
        let dr = bits![mut 0; 32];
        dr[0..32].store_le(ARM_TAP.0);
        let idcodes = extract_idcodes(dr).unwrap();
        assert_eq!(idcodes, vec![Some(ARM_TAP)]);
    }

    #[test]
    fn extract_id_codes_two_taps() {
        let dr = bits![mut 0; 64];
        dr[0..32].store_le(ARM_TAP.0);
        dr[32..64].store_le(STM_BS_TAP.0);
        let idcodes = extract_idcodes(dr).unwrap();
        assert_eq!(idcodes, vec![Some(ARM_TAP), Some(STM_BS_TAP)]);
    }

    #[test]
    fn extract_id_codes_tap_bypass_tap() {
        let dr = bits![mut 0; 65];
        dr[0..32].store_le(ARM_TAP.0);
        dr.set(32, false);
        dr[33..65].store_le(STM_BS_TAP.0);
        let idcodes = extract_idcodes(dr).unwrap();
        assert_eq!(idcodes, vec![Some(ARM_TAP), None, Some(STM_BS_TAP)]);
    }

    #[test]
    fn extract_id_codes_stops_at_all_ones() {
        let dr = bits![mut 0; 64];
        dr[0..32].store_le(ARM_TAP.0);
        dr[32..64].store_le(u32::MAX);
        let idcodes = extract_idcodes(dr).unwrap();
        assert_eq!(idcodes, vec![Some(ARM_TAP)]);
    }

    #[test]
    fn extract_id_codes_rejects_truncated_idcode() {
        let dr = bits![mut 0; 40];
        dr[0..32].store_le(ARM_TAP.0);
        dr.set(32, true);
        assert!(extract_idcodes(dr).is_err());
    }

    #[test]
    fn extract_id_codes_rejects_invalid_idcode() {
        // Bit 0 set but manufacturer field 0.
        let dr = bits![mut 0; 32];
        dr.set(0, true);
        assert!(extract_idcodes(dr).is_err());
    }

    #[test]
    fn reset_from_ir_shift() {
        let mut state = JtagState::Ir(RegisterState::Shift);
        state.update(true);
        state.update(true);
        state.update(true);
        state.update(true);
        state.update(true);
        assert_eq!(state, JtagState::Reset);
    }

    #[test]
    fn idle_from_reset() {
        let mut state = JtagState::Reset;
        state.update(false);
        assert_eq!(state, JtagState::Idle);
    }

    #[test]
    fn generated_bits_lead_to_correct_state() {
        // Every (start, goal) pair must be reachable by following `step_toward`.
        for start in ALL_STATES {
            for goal in ALL_STATES {
                let mut state = start;
                let mut transitions = 0;
                while let Some(tms) = state.step_toward(goal) {
                    state.update(tms);
                    transitions += 1;
                    assert!(
                        transitions < 16,
                        "no path from {start:?} to {goal:?} within 16 transitions"
                    );
                }
                assert_eq!(state, goal);
            }
        }
    }
}