1use std::vec;
23
24use amplify::confinement::Confined;
25use amplify::{confinement, Bytes32StrRev, Wrapper};
26
27use crate::opcodes::*;
28use crate::{
29    ByteStr, RedeemScript, ScriptBytes, ScriptPubkey, VarIntArray, WScriptHash, LIB_NAME_BITCOIN,
30};
31
32#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Display, Error)]
33#[display(doc_comments)]
34pub enum SegwitError {
35    InvalidWitnessVersion(u8),
37    MalformedWitnessVersion,
40    InvalidWitnessProgramLength(usize),
42    InvalidSegwitV0ProgramLength(usize),
44    UncompressedPubkey,
46}
47
/// Segwit witness version, 0 to 16.
///
/// The discriminant of each variant is the script opcode byte that encodes
/// the version (`OP_PUSHBYTES_0` for v0, `OP_PUSHNUM_1` ..= `OP_PUSHNUM_16`
/// for v1 ..= v16); strict-encoding tags use this `repr` byte. Displayed as
/// `segwitN`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Display)]
#[derive(StrictType, StrictEncode, StrictDecode, StrictDumb)]
#[strict_type(lib = LIB_NAME_BITCOIN, tags = repr, into_u8, try_from_u8)]
#[repr(u8)]
pub enum WitnessVer {
    /// Witness version 0, encoded as `OP_PUSHBYTES_0`; used by P2WPKH and
    /// P2WSH outputs. Also the strict-encoding "dumb" (placeholder) value.
    #[strict_type(dumb)]
    #[display("segwit0")]
    V0 = OP_PUSHBYTES_0,

    /// Witness version 1, encoded as `OP_PUSHNUM_1`.
    #[display("segwit1")]
    V1 = OP_PUSHNUM_1,

    /// Witness version 2, encoded as `OP_PUSHNUM_2`.
    #[display("segwit2")]
    V2 = OP_PUSHNUM_2,

    /// Witness version 3, encoded as `OP_PUSHNUM_3`.
    #[display("segwit3")]
    V3 = OP_PUSHNUM_3,

    /// Witness version 4, encoded as `OP_PUSHNUM_4`.
    #[display("segwit4")]
    V4 = OP_PUSHNUM_4,

    /// Witness version 5, encoded as `OP_PUSHNUM_5`.
    #[display("segwit5")]
    V5 = OP_PUSHNUM_5,

    /// Witness version 6, encoded as `OP_PUSHNUM_6`.
    #[display("segwit6")]
    V6 = OP_PUSHNUM_6,

    /// Witness version 7, encoded as `OP_PUSHNUM_7`.
    #[display("segwit7")]
    V7 = OP_PUSHNUM_7,

    /// Witness version 8, encoded as `OP_PUSHNUM_8`.
    #[display("segwit8")]
    V8 = OP_PUSHNUM_8,

    /// Witness version 9, encoded as `OP_PUSHNUM_9`.
    #[display("segwit9")]
    V9 = OP_PUSHNUM_9,

    /// Witness version 10, encoded as `OP_PUSHNUM_10`.
    #[display("segwit10")]
    V10 = OP_PUSHNUM_10,

    /// Witness version 11, encoded as `OP_PUSHNUM_11`.
    #[display("segwit11")]
    V11 = OP_PUSHNUM_11,

    /// Witness version 12, encoded as `OP_PUSHNUM_12`.
    #[display("segwit12")]
    V12 = OP_PUSHNUM_12,

    /// Witness version 13, encoded as `OP_PUSHNUM_13`.
    #[display("segwit13")]
    V13 = OP_PUSHNUM_13,

    /// Witness version 14, encoded as `OP_PUSHNUM_14`.
    #[display("segwit14")]
    V14 = OP_PUSHNUM_14,

    /// Witness version 15, encoded as `OP_PUSHNUM_15`.
    #[display("segwit15")]
    V15 = OP_PUSHNUM_15,

    /// Witness version 16, encoded as `OP_PUSHNUM_16`.
    #[display("segwit16")]
    V16 = OP_PUSHNUM_16,
}
126
127impl WitnessVer {
128    pub fn from_op_code(op_code: OpCode) -> Result<Self, SegwitError> {
134        match op_code as u8 {
135            0 => Ok(WitnessVer::V0),
136            OP_PUSHNUM_1 => Ok(WitnessVer::V1),
137            OP_PUSHNUM_2 => Ok(WitnessVer::V2),
138            OP_PUSHNUM_3 => Ok(WitnessVer::V3),
139            OP_PUSHNUM_4 => Ok(WitnessVer::V4),
140            OP_PUSHNUM_5 => Ok(WitnessVer::V5),
141            OP_PUSHNUM_6 => Ok(WitnessVer::V6),
142            OP_PUSHNUM_7 => Ok(WitnessVer::V7),
143            OP_PUSHNUM_8 => Ok(WitnessVer::V8),
144            OP_PUSHNUM_9 => Ok(WitnessVer::V9),
145            OP_PUSHNUM_10 => Ok(WitnessVer::V10),
146            OP_PUSHNUM_11 => Ok(WitnessVer::V11),
147            OP_PUSHNUM_12 => Ok(WitnessVer::V12),
148            OP_PUSHNUM_13 => Ok(WitnessVer::V13),
149            OP_PUSHNUM_14 => Ok(WitnessVer::V14),
150            OP_PUSHNUM_15 => Ok(WitnessVer::V15),
151            OP_PUSHNUM_16 => Ok(WitnessVer::V16),
152            _ => Err(SegwitError::MalformedWitnessVersion),
153        }
154    }
155
156    pub fn from_version_no(no: u8) -> Result<Self, SegwitError> {
162        Ok(match no {
163            v if v == Self::V0.version_no() => Self::V0,
164            v if v == Self::V1.version_no() => Self::V1,
165            v if v == Self::V2.version_no() => Self::V2,
166            v if v == Self::V3.version_no() => Self::V3,
167            v if v == Self::V4.version_no() => Self::V4,
168            v if v == Self::V5.version_no() => Self::V5,
169            v if v == Self::V6.version_no() => Self::V6,
170            v if v == Self::V7.version_no() => Self::V7,
171            v if v == Self::V8.version_no() => Self::V8,
172            v if v == Self::V9.version_no() => Self::V9,
173            v if v == Self::V10.version_no() => Self::V10,
174            v if v == Self::V11.version_no() => Self::V11,
175            v if v == Self::V12.version_no() => Self::V12,
176            v if v == Self::V13.version_no() => Self::V13,
177            v if v == Self::V14.version_no() => Self::V14,
178            v if v == Self::V15.version_no() => Self::V15,
179            v if v == Self::V16.version_no() => Self::V16,
180            _ => return Err(SegwitError::InvalidWitnessVersion(no)),
181        })
182    }
183
184    pub fn op_code(self) -> OpCode {
188        OpCode::try_from(self as u8).expect("full range of u8 is covered")
189    }
190
191    pub fn version_no(self) -> u8 {
193        match self {
194            WitnessVer::V0 => 0,
195            WitnessVer::V1 => 1,
196            WitnessVer::V2 => 2,
197            WitnessVer::V3 => 3,
198            WitnessVer::V4 => 4,
199            WitnessVer::V5 => 5,
200            WitnessVer::V6 => 6,
201            WitnessVer::V7 => 7,
202            WitnessVer::V8 => 8,
203            WitnessVer::V9 => 9,
204            WitnessVer::V10 => 10,
205            WitnessVer::V11 => 11,
206            WitnessVer::V12 => 12,
207            WitnessVer::V13 => 13,
208            WitnessVer::V14 => 14,
209            WitnessVer::V15 => 15,
210            WitnessVer::V16 => 16,
211        }
212    }
213}
214
/// Witness program: a witness version plus the program bytes that follow it
/// in a segwit `scriptPubkey`.
///
/// Construct with `WitnessProgram::new`, which validates the program length.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(StrictType, StrictEncode, StrictDecode, StrictDumb)]
#[strict_type(lib = LIB_NAME_BITCOIN, dumb = Self::dumb())]
pub struct WitnessProgram {
    /// Witness version of the program.
    version: WitnessVer,
    /// Program data; the type confines it to 2..=40 bytes.
    program: Confined<Vec<u8>, 2, 40>,
}
225
226impl WitnessProgram {
227    fn dumb() -> Self { Self::new(strict_dumb!(), vec![0; 32]).unwrap() }
228
229    pub fn new(version: WitnessVer, program: Vec<u8>) -> Result<Self, SegwitError> {
231        let len = program.len();
232        let program = Confined::try_from(program)
233            .map_err(|_| SegwitError::InvalidWitnessProgramLength(len))?;
234
235        if version == WitnessVer::V0 && (program.len() != 20 && program.len() != 32) {
238            return Err(SegwitError::InvalidSegwitV0ProgramLength(program.len()));
239        }
240
241        Ok(WitnessProgram { version, program })
242    }
243
244    pub fn version(&self) -> WitnessVer { self.version }
246
247    pub fn program(&self) -> &[u8] { &self.program }
249}
250
251impl ScriptPubkey {
252    pub fn p2wpkh(hash: impl Into<[u8; 20]>) -> Self {
253        Self::with_witness_program_unchecked(WitnessVer::V0, &hash.into())
254    }
255
256    pub fn p2wsh(hash: impl Into<[u8; 32]>) -> Self {
257        Self::with_witness_program_unchecked(WitnessVer::V0, &hash.into())
258    }
259
260    pub fn is_p2wpkh(&self) -> bool {
261        self.len() == 22 && self[0] == WitnessVer::V0.op_code() as u8 && self[1] == OP_PUSHBYTES_20
262    }
263
264    pub fn is_p2wsh(&self) -> bool {
265        self.len() == 34 && self[0] == WitnessVer::V0.op_code() as u8 && self[1] == OP_PUSHBYTES_32
266    }
267
268    pub fn from_witness_program(witness_program: &WitnessProgram) -> Self {
270        Self::with_witness_program_unchecked(witness_program.version, witness_program.program())
271    }
272
273    pub(crate) fn with_witness_program_unchecked(ver: WitnessVer, prog: &[u8]) -> Self {
276        let mut script = Self::with_capacity(ScriptBytes::len_for_slice(prog.len()) + 2);
277        script.push_opcode(ver.op_code());
278        script.push_slice(prog);
279        script
280    }
281
282    #[inline]
284    pub fn is_witness_program(&self) -> bool {
285        let script_len = self.len();
291        if !(4..=42).contains(&script_len) {
292            return false;
293        }
294        let Ok(ver_opcode) = OpCode::try_from(self[0]) else {
296            return false;
297        };
298        let push_opbyte = self[1]; WitnessVer::from_op_code(ver_opcode).is_ok()
300            && (OP_PUSHBYTES_2..=OP_PUSHBYTES_40).contains(&push_opbyte)
301            && script_len - 2 == push_opbyte as usize
303    }
304}
305
/// Witness script: the script whose [`WScriptHash`] a P2WSH output commits
/// to.
///
/// This is a newtype over [`ScriptBytes`]. With the `serde` feature it
/// (de)serializes transparently as the inner bytes.
#[derive(Wrapper, WrapperMut, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug, From, Default)]
#[wrapper(Deref, AsSlice, Hex)]
#[wrapper_mut(DerefMut, AsSliceMut)]
#[derive(StrictType, StrictEncode, StrictDecode)]
#[strict_type(lib = LIB_NAME_BITCOIN)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate", transparent)
)]
pub struct WitnessScript(ScriptBytes);
317
318impl TryFrom<Vec<u8>> for WitnessScript {
319    type Error = confinement::Error;
320    fn try_from(script_bytes: Vec<u8>) -> Result<Self, Self::Error> {
321        ScriptBytes::try_from(script_bytes).map(Self)
322    }
323}
324
325impl WitnessScript {
326    #[inline]
327    pub fn new() -> Self { Self::default() }
328
329    #[inline]
330    pub fn with_capacity(capacity: usize) -> Self {
331        Self(ScriptBytes::from(Confined::with_capacity(capacity)))
332    }
333
334    #[inline]
337    pub fn from_unsafe(script_bytes: Vec<u8>) -> Self {
338        Self(ScriptBytes::from_unsafe(script_bytes))
339    }
340
341    #[inline]
343    pub fn push_opcode(&mut self, op_code: OpCode) { self.0.push(op_code as u8); }
344
345    pub fn to_redeem_script(&self) -> RedeemScript {
346        let script = ScriptPubkey::p2wsh(WScriptHash::from(self));
347        RedeemScript::from_inner(script.into_inner())
348    }
349
350    pub fn to_script_pubkey(&self) -> ScriptPubkey { ScriptPubkey::p2wsh(WScriptHash::from(self)) }
351
352    #[inline]
353    pub fn as_script_bytes(&self) -> &ScriptBytes { &self.0 }
354}
355
/// Witness transaction identifier (wtxid), a 32-byte value.
///
/// It can be built from a `[u8; 32]` or a [`Bytes32StrRev`]. Its
/// `Display`/`FromStr`/hex forms come from the inner [`Bytes32StrRev`].
/// NOTE(review): presumably that type renders the bytes in reversed order,
/// as is conventional for bitcoin ids; confirm against amplify.
#[derive(Wrapper, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, From)]
#[wrapper(BorrowSlice, Index, RangeOps, Debug, Hex, Display, FromStr)]
#[derive(StrictType, StrictDumb, StrictEncode, StrictDecode)]
#[strict_type(lib = LIB_NAME_BITCOIN)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate", transparent)
)]
pub struct Wtxid(
    #[from]
    #[from([u8; 32])]
    Bytes32StrRev,
);
370
/// Segwit witness: a stack of byte-string elements.
///
/// The elements are stored in a [`VarIntArray`], which bounds how many of
/// them there can be. Read-only access to the array is available via `Deref`.
#[derive(Wrapper, Clone, Eq, PartialEq, Hash, Debug, From, Default)]
#[wrapper(Deref, Index, RangeOps)]
#[derive(StrictType, StrictEncode, StrictDecode)]
#[strict_type(lib = LIB_NAME_BITCOIN)]
pub struct Witness(VarIntArray<ByteStr>);
376
377impl IntoIterator for Witness {
378    type Item = ByteStr;
379    type IntoIter = vec::IntoIter<ByteStr>;
380
381    fn into_iter(self) -> Self::IntoIter { self.0.into_iter() }
382}
383
384impl Witness {
385    #[inline]
386    pub fn new() -> Self { default!() }
387
388    #[inline]
389    pub fn elements(&self) -> impl Iterator<Item = &'_ [u8]> {
390        self.0.iter().map(|el| el.as_slice())
391    }
392
393    pub fn from_consensus_stack(witness: impl IntoIterator<Item = Vec<u8>>) -> Witness {
394        let iter = witness.into_iter().map(ByteStr::from);
395        let stack =
396            VarIntArray::try_from_iter(iter).expect("witness stack size exceeds 2^32 elements");
397        Witness(stack)
398    }
399
400    #[inline]
401    pub(crate) fn as_var_int_array(&self) -> &VarIntArray<ByteStr> { &self.0 }
402}
403
/// Serde support for [`Witness`]: the witness is represented as a sequence of
/// its stack elements.
#[cfg(feature = "serde")]
mod _serde {
    use serde::{Deserialize, Serialize};
    use serde_crate::ser::SerializeSeq;
    use serde_crate::{Deserializer, Serializer};

    use super::*;

    impl Serialize for Witness {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where S: Serializer {
            let mut ser = serializer.serialize_seq(Some(self.len()))?;
            for el in &self.0 {
                ser.serialize_element(&el)?;
            }
            ser.end()
        }
    }

    impl<'de> Deserialize<'de> for Witness {
        /// Deserializes a sequence of stack elements.
        ///
        /// Deserialized data is untrusted, so an over-long stack is reported
        /// as a deserialization error instead of panicking, as
        /// `Witness::from_consensus_stack` would. The elements are also moved
        /// straight into the array without a `ByteStr` -> `Vec<u8>` ->
        /// `ByteStr` round trip.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where D: Deserializer<'de> {
            let stack = Vec::<ByteStr>::deserialize(deserializer)?;
            VarIntArray::try_from_iter(stack.into_iter()).map(Witness).map_err(|_| {
                <D::Error as serde_crate::de::Error>::custom(
                    "witness stack exceeds the maximum number of elements",
                )
            })
        }
    }
}